def testComponentSpecWithRuntimeParam(self):
  """A RuntimeParameter used as a split name is serialized via str().

  The parameter is embedded in a dict form of an `example_gen_pb2.Input`
  message. After the spec is built, the JSON-encoded `proto` exec property
  must carry `str(param)` where the parameter object was.
  """
  split_param = data_types.RuntimeParameter(name='split-1', ptype=Text)
  expected_param_str = str(split_param)
  # Dict representation of a example_gen_pb2.Input proto message.
  split_dicts = [
      dict(name=split_param, pattern='pattern1'),
      dict(name='name2', pattern='pattern2'),
      dict(name='name3', pattern='pattern3'),
  ]
  proto_dict = dict(splits=split_dicts)
  in_chan = Channel(type=_InputArtifact)
  out_chan = Channel(type=_OutputArtifact)
  spec = _BasicComponentSpec(
      folds=10, proto=proto_dict, input=in_chan, output=out_chan)

  # Verify proto property.
  encoded = spec.exec_properties['proto']
  self.assertIsInstance(encoded, str)
  decoded = json.loads(encoded)
  self.assertCountEqual(['splits'], decoded.keys())
  splits = decoded['splits']
  self.assertEqual(3, len(splits))
  self.assertCountEqual([expected_param_str, 'name2', 'name3'],
                        [split['name'] for split in splits])
  self.assertCountEqual(['pattern1', 'pattern2', 'pattern3'],
                        [split['pattern'] for split in splits])
def testEncodeWithKeys(self):
  """encode_with_keys puts the callback's key into the placeholder.

  The key callback here returns the channel's type name. After encoding, the
  placeholder's `_key` attribute must still be None.
  """
  channel = Channel(type=_MyType)
  channel_future = channel.future()[0].value
  # Use a distinct lambda parameter name so it does not shadow `channel`.
  actual_pb = channel_future.encode_with_keys(lambda ch: ch.type_name)
  expected_pb = text_format.Parse(
      """ operator { artifact_value_op { expression { operator { index_op { expression { placeholder { key: "MyTypeName" } } } } } } } """,
      placeholder_pb2.PlaceholderExpression())
  # assertProtoEquals takes (expected, actual); keep that order so failure
  # diffs label the two protos correctly.
  self.assertProtoEquals(expected_pb, actual_pb)
  self.assertIsNone(channel_future._key)
def testComponentspecBasic(self):
  """Exercises _BasicComponentSpec construction and its type checks.

  Covers the JSON encoding of the `proto` exec property, plain exec
  properties, the input/output channels, the `future_*_name` compatibility
  aliases, and the TypeErrors raised for a wrongly typed parameter or channel.
  """
  proto = example_gen_pb2.Input()
  proto.splits.extend([
      example_gen_pb2.Input.Split(name='name1', pattern='pattern1'),
      example_gen_pb2.Input.Split(name='name2', pattern='pattern2'),
      example_gen_pb2.Input.Split(name='name3', pattern='pattern3'),
  ])
  input_channel = Channel(type_name='InputType')
  output_channel = Channel(type_name='OutputType')
  spec = _BasicComponentSpec(
      folds=10, proto=proto, input=input_channel, output=output_channel)

  # Verify proto property.
  self.assertIsInstance(spec.exec_properties['proto'], str)
  decoded_proto = json.loads(spec.exec_properties['proto'])
  self.assertCountEqual(['splits'], decoded_proto.keys())
  self.assertEqual(3, len(decoded_proto['splits']))
  self.assertCountEqual(['name1', 'name2', 'name3'],
                        [s['name'] for s in decoded_proto['splits']])
  self.assertCountEqual(['pattern1', 'pattern2', 'pattern3'],
                        [s['pattern'] for s in decoded_proto['splits']])

  # Verify other properties.
  self.assertEqual(10, spec.exec_properties['folds'])
  self.assertIs(spec.inputs['input'], input_channel)
  self.assertIs(spec.outputs['output'], output_channel)

  # Verify compatibility aliasing behavior.
  self.assertIs(spec.inputs['future_input_name'], spec.inputs['input'])
  self.assertIs(spec.outputs['future_output_name'], spec.outputs['output'])

  # `assertRaisesRegexp` is a deprecated alias of `assertRaisesRegex` and was
  # removed in Python 3.12.
  with self.assertRaisesRegex(
      TypeError,
      "Expected type <(class|type) 'int'> for parameter u?'folds' but got "
      'string.'):
    _BasicComponentSpec(
        folds='string', input=input_channel, output=output_channel)

  with self.assertRaisesRegex(
      TypeError,
      '.*should be a Channel of .*InputType.*got (.|\\s)*WrongType.*'):
    _BasicComponentSpec(
        folds=10, input=Channel(type_name='WrongType'), output=output_channel)

  with self.assertRaisesRegex(
      TypeError,
      '.*should be a Channel of .*OutputType.*got (.|\\s)*WrongType.*'):
    _BasicComponentSpec(
        folds=10, input=input_channel, output=Channel(type_name='WrongType'))
def testJsonRoundTripUnknownArtifactClass(self):
  """Rehydrates a channel whose serialized type name is unknown.

  The rehydrated channel keeps the unknown type name and the original
  artifact type's properties, and its type is flagged `_AUTOGENERATED`.
  """
  original = Channel(type=_MyType)
  json_dict = original.to_json_dict()
  json_dict['type']['name'] = 'UnknownTypeName'
  restored = Channel.from_json_dict(json_dict)

  self.assertEqual('UnknownTypeName', restored.type_name)
  original_props = original.type._get_artifact_type().properties
  restored_props = restored.type._get_artifact_type().properties
  self.assertEqual(original_props, restored_props)
  self.assertTrue(restored.type._AUTOGENERATED)
def testProtoFutureValueOperator(self):
  """Encoding an Integer channel's future value matches a golden pbtxt.

  The expected proto is read from
  testdata/proto_placeholder_future_value_operator.pbtxt next to this file.
  """
  test_pb_filepath = os.path.join(
      os.path.dirname(__file__), 'testdata',
      'proto_placeholder_future_value_operator.pbtxt')
  with open(test_pb_filepath) as text_pb_file:
    expected_pb = text_format.ParseLines(
        text_pb_file, placeholder_pb2.PlaceholderExpression())

  output_channel = Channel(type=standard_artifacts.Integer)
  placeholder = output_channel.future()[0].value
  # Set the key directly; presumably the golden pbtxt expects this key in
  # the encoded placeholder — verify against the testdata file.
  placeholder._key = '_component.num'
  # assertProtoEquals takes (expected, actual).
  self.assertProtoEquals(expected_pb, placeholder.encode())
def test_channel_utils_as_channel_success(self):
  """as_channel returns a channel equal to the Channel it was given."""
  artifacts = [Artifact('MyTypeName'), Artifact('MyTypeName')]
  source = Channel('MyTypeName', artifacts=artifacts)
  result = channel_utils.as_channel(source)
  self.assertEqual(source, result)
def _render_channel_as_mdstr(input_channel: channel.Channel) -> Text:
  """Render a Channel as markdown string with the following format.

    **Type**: input_channel.type_name
    **Artifact: artifact1**
    **Properties**:
    **key1**: value1
    **key2**: value2
    ......

  Args:
    input_channel: the channel to be rendered.

  Returns:
    A md-formatted string: a "**Type**" header (type name passed through
    `_sanitize_underscore`) followed by every artifact from
    `input_channel.get()` rendered with `_render_artifact_as_mdstr`, with the
    artifact sections separated by blank lines.
  """
  header = '**Type**: {}\n\n'.format(
      _sanitize_underscore(input_channel.type_name))
  # One markdown section per artifact currently held by the channel.
  artifact_sections = '\n\n'.join(
      _render_artifact_as_mdstr(artifact) for artifact in input_channel.get())
  return header + artifact_sections
def testDumpUiMetadata(self):
  """_dump_ui_metadata records the model run as a tensorboard output.

  The metadata JSON goes under $TEST_UNDECLARED_OUTPUTS_DIR when that variable
  is set, and under the test's temp dir otherwise. Its last `outputs` entry
  must have type 'tensorboard' and the ModelRun URI as its source.
  """
  trainer = Trainer(
      examples=Channel(type=standard_artifacts.Examples),
      module_file='module_file',
      train_args=trainer_pb2.TrainArgs(splits=['train'], num_steps=100),
      eval_args=trainer_pb2.EvalArgs(splits=['eval'], num_steps=50))
  model_run = standard_artifacts.ModelRun()
  model_run.uri = 'model_run_uri'
  exec_info = data_types.ExecutionInfo(
      input_dict={},
      output_dict={'model_run': [model_run]},
      exec_properties={},
      execution_id='id')

  output_root = os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR',
                               self.get_temp_dir())
  ui_metadata_path = os.path.join(output_root, self._testMethodName, 'json')
  fileio.makedirs(os.path.dirname(ui_metadata_path))

  container_entrypoint._dump_ui_metadata(trainer, exec_info, ui_metadata_path)

  with open(ui_metadata_path) as metadata_file:
    last_output = json.load(metadata_file)['outputs'][-1]
  self.assertEqual('tensorboard', last_output['type'])
  self.assertEqual('model_run_uri', last_output['source'])
def as_channel(source: Union[Channel, Iterable[Artifact]]) -> Channel:
  """Converts artifact collection of the same artifact type into a Channel.

  Args:
    source: Either a Channel or an iterable of Artifact.

  Returns:
    `source` itself if it is already a Channel. Otherwise a static Channel
    whose type name comes from the first artifact and which contains every
    artifact in `source`.

  Raises:
    ValueError: when source is neither a Channel nor a non-empty iterable
      whose first element is an Artifact.
  """
  # `collections.Iterable` was removed in Python 3.10; the ABCs live in
  # `collections.abc`.
  import collections.abc  # pylint: disable=g-import-not-at-top

  if isinstance(source, Channel):
    return source
  if not isinstance(source, collections.abc.Iterable):
    raise ValueError('Invalid source to be a channel: {}'.format(source))
  # Peeking at the first element would consume it from a one-shot iterator
  # (e.g. a generator) and drop it from the Channel, so materialize it first.
  if isinstance(source, collections.abc.Iterator):
    source = list(source)
  sentinel = object()
  first_element = next(iter(source), sentinel)
  if first_element is sentinel:
    raise ValueError('Cannot convert empty artifact collection into Channel')
  if not isinstance(first_element, Artifact):
    raise ValueError('Invalid source to be a channel: {}'.format(source))
  return Channel(type_name=first_element.type_name, artifacts=source)
def testComponentSpecWithRuntimeParam(self):
  """RuntimeParameters given as exec properties are stored as-is.

  Both `folds` and `proto` remain RuntimeParameter objects in
  `exec_properties`, and the proto parameter keeps its default string.
  """
  default_proto = '{"splits": [{"name": "name1", "pattern": "pattern1"}]}'
  proto_param = data_types.RuntimeParameter(
      name='proto', ptype=str, default=default_proto)
  folds_param = data_types.RuntimeParameter(name='int', ptype=int)
  spec = _BasicComponentSpec(
      folds=folds_param,
      proto=proto_param,
      input=Channel(type=_InputArtifact),
      output=Channel(type=_OutputArtifact))

  props = spec.exec_properties
  self.assertIsInstance(props['folds'], data_types.RuntimeParameter)
  self.assertIsInstance(props['proto'], data_types.RuntimeParameter)
  self.assertEqual(props['proto'].default, default_proto)
def testComponentspecMissingArguments(self):
  """Omitting a required parameter or input raises 'Missing argument'.

  `y` is declared optional, so leaving it out is accepted.
  """

  class SimpleComponentSpec(ComponentSpec):
    PARAMETERS = {
        'x': ExecutionParameter(type=int),
        'y': ExecutionParameter(type=int, optional=True),
    }
    INPUTS = {'z': ChannelParameter(type_name='Z')}
    OUTPUTS = {}

  # `assertRaisesRegexp` is a deprecated alias of `assertRaisesRegex` and was
  # removed in Python 3.12.
  with self.assertRaisesRegex(ValueError, 'Missing argument'):
    _ = SimpleComponentSpec(x=10)
  with self.assertRaisesRegex(ValueError, 'Missing argument'):
    _ = SimpleComponentSpec(z=Channel(type_name='Z'))
  # Okay since y is optional.
  _ = SimpleComponentSpec(x=10, z=Channel(type_name='Z'))
def testValidateDynamicExecPhOperator(self):
  """Checks which placeholders validate_dynamic_exec_ph_operator accepts.

  It must reject a whole future, a `.uri` placeholder, and the concatenation
  of two `.value` placeholders. It must accept a single `.value` placeholder.
  """
  # Build every placeholder outside `assertRaises` so that a ValueError raised
  # while constructing one cannot make the test pass by accident.
  whole_future_ph = Channel(type=_MyType).future()
  uri_ph = Channel(type=_MyType).future()[0].uri
  concat_ph = (
      Channel(type=_MyType).future()[0].value +
      Channel(type=_MyType).future()[0].value)

  with self.assertRaises(ValueError):
    compiler_utils.validate_dynamic_exec_ph_operator(whole_future_ph)
  with self.assertRaises(ValueError):
    compiler_utils.validate_dynamic_exec_ph_operator(uri_ph)
  with self.assertRaises(ValueError):
    compiler_utils.validate_dynamic_exec_ph_operator(concat_ph)

  valid_dynamic_exec_ph = Channel(type=_MyType).future()[0].value
  compiler_utils.validate_dynamic_exec_ph_operator(valid_dynamic_exec_ph)
def testUnwrapChannelDict(self):
  """unwrap_channel_dict maps each key to the artifacts of its channel."""
  first, second = Artifact('MyTypeName'), Artifact('MyTypeName')
  channels = {'id': Channel('MyTypeName', artifacts=[first, second])}
  unwrapped = channel_utils.unwrap_channel_dict(channels)
  self.assertDictEqual(unwrapped, {'id': [first, second]})
def testUnwrapChannelDict(self):
  """unwrap_channel_dict maps each key to the artifacts of its channel."""
  first, second = _MyArtifact(), _MyArtifact()
  channels = {'id': Channel(_MyArtifact, artifacts=[first, second])}
  unwrapped = channel_utils.unwrap_channel_dict(channels)
  self.assertDictEqual(unwrapped, {'id': [first, second]})
class ChannelWrappedPlaceholderTest(parameterized.TestCase, tf.test.TestCase):
  """Tests for placeholders obtained from a Channel's future."""

  @parameterized.named_parameters(
      {
          'testcase_name': 'two_sides_placeholder',
          'left': Channel(type=_MyType).future().value,
          'right': Channel(type=_MyType).future().value,
      },
      {
          'testcase_name': 'left_side_placeholder_right_side_string',
          'left': Channel(type=_MyType).future().value,
          'right': '#',
      },
      {
          'testcase_name': 'left_side_string_right_side_placeholder',
          'left': 'http://',
          'right': Channel(type=_MyType).future().value,
      },
  )
  def testConcat(self, left, right):
    """Concatenation involving a channel placeholder stays channel-wrapped."""
    self.assertIsInstance(left + right, ph.ChannelWrappedPlaceholder)

  def testEncodeWithKeys(self):
    """encode_with_keys puts the callback's key into the placeholder.

    After encoding, the placeholder's `_key` attribute must still be None.
    """
    channel = Channel(type=_MyType)
    channel_future = channel.future()[0].value
    # Use a distinct lambda parameter name so it does not shadow `channel`.
    actual_pb = channel_future.encode_with_keys(lambda ch: ch.type_name)
    expected_pb = text_format.Parse(
        """ operator { artifact_value_op { expression { operator { index_op { expression { placeholder { key: "MyTypeName" } } } } } } } """,
        placeholder_pb2.PlaceholderExpression())
    # assertProtoEquals takes (expected, actual).
    self.assertProtoEquals(expected_pb, actual_pb)
    self.assertIsNone(channel_future._key)
def testProtoTypeCheck(self):
  """A proto dict with a wrongly typed field fails with a ParseError."""
  param = data_types.RuntimeParameter(name='split-1', ptype=Text)
  # Dict representation of a example_gen_pb2.Input proto message.
  # The second split has int-typed pattern, which is wrong.
  proto = dict(splits=[
      dict(name=param, pattern='pattern1'),
      dict(name='name2', pattern=42),
      dict(name='name3', pattern='pattern3'),
  ])
  input_channel = Channel(type_name='InputType')
  output_channel = Channel(type_name='OutputType')
  # `assertRaisesRegexp` is a deprecated alias of `assertRaisesRegex` and was
  # removed in Python 3.12.
  with self.assertRaisesRegex(
      ParseError, 'Failed to parse .* field: expected string or '
      '(bytes-like object|buffer)'):
    _BasicComponentSpec(
        folds=10, proto=proto, input=input_channel, output=output_channel)
def testJsonRoundTrip(self):
  """A channel survives to_json_dict/from_json_dict with its properties.

  The type (by identity), type name, additional properties and additional
  custom properties must all match after the round trip.
  """
  string_prop = metadata_store_pb2.Value(string_value='forty-two')
  int_prop = metadata_store_pb2.Value(int_value=42)
  original = Channel(
      type=_MyType,
      additional_properties={'string_value': string_prop},
      additional_custom_properties={'int_value': int_prop})

  rehydrated = Channel.from_json_dict(original.to_json_dict())

  self.assertIs(original.type, rehydrated.type)
  self.assertEqual(original.type_name, rehydrated.type_name)
  for attr_name in ('additional_properties', 'additional_custom_properties'):
    self.assertEqual(
        getattr(original, attr_name), getattr(rehydrated, attr_name))
def testOptionalOutputs(self):
  """An optional output appears in `outputs` only when it is supplied."""

  class SpecWithOptionalOutput(ComponentSpec):
    PARAMETERS = {}
    INPUTS = {}
    OUTPUTS = {'x': ChannelParameter(type=_Z, optional=True)}

  spec_without_x = SpecWithOptionalOutput()
  self.assertNotIn('x', spec_without_x.outputs.keys())

  spec_with_x = SpecWithOptionalOutput(x=Channel(type=_Z))
  self.assertIn('x', spec_with_x.outputs.keys())
def testEquals(self):
  """`==` between two future values encodes an EQUAL comparison."""
  left_channel = Channel(type=_MyType)
  right_channel = Channel(type=_MyType)
  predicate = left_channel.future().value == right_channel.future().value
  encoded = predicate.encode()
  expected_op = placeholder_pb2.ComparisonOperator.Operation.EQUAL
  self.assertEqual(encoded.operator.compare_op.op, expected_op)
def build_input_artifact_spec(
    channel_spec: channel.Channel
) -> pipeline_pb2.ComponentInputsSpec.ArtifactSpec:
  """Builds artifact type spec for an input channel.

  Args:
    channel_spec: the input channel. Its `type` is instantiated to derive the
      artifact schema.

  Returns:
    A `ComponentInputsSpec.ArtifactSpec` whose `artifact_type.instance_schema`
    is the schema of a freshly created instance of the channel's type.
  """
  artifact_cls = channel_spec.type
  schema = get_artifact_schema(artifact_cls())
  spec = pipeline_pb2.ComponentInputsSpec.ArtifactSpec()
  spec.artifact_type.CopyFrom(
      pipeline_pb2.ArtifactTypeSchema(instance_schema=schema))
  # Validate the type's PROPERTIES against the schema just stored.
  _validate_properties_schema(
      instance_schema=spec.artifact_type.instance_schema,
      properties=artifact_cls.PROPERTIES)
  return spec
def build_output_artifact_spec(
    channel_spec: channel.Channel
) -> pipeline_pb2.ComponentOutputsSpec.ArtifactSpec:
  """Builds artifact type spec for an output channel.

  The first artifact already held by the channel is used as the reference
  instance. If the channel holds none, a new instance of the channel's type
  is created instead.

  Args:
    channel_spec: the output channel.

  Returns:
    A `ComponentOutputsSpec.ArtifactSpec` carrying the artifact schema. When
    `pack_artifact_properties` returns a truthy value for the reference
    instance, that value is copied into `metadata`.
  """
  existing_artifacts = list(channel_spec.get())
  if existing_artifacts:
    reference_artifact = existing_artifacts[0]
  else:
    reference_artifact = channel_spec.type()

  spec = pipeline_pb2.ComponentOutputsSpec.ArtifactSpec()
  spec.artifact_type.CopyFrom(
      pipeline_pb2.ArtifactTypeSchema(
          instance_schema=get_artifact_schema(reference_artifact)))
  _validate_properties_schema(
      instance_schema=spec.artifact_type.instance_schema,
      properties=channel_spec.type.PROPERTIES)

  properties_struct = pack_artifact_properties(reference_artifact)
  if properties_struct:
    spec.metadata.CopyFrom(properties_struct)
  return spec
def build_output_artifact_spec(
    channel_spec: channel.Channel
) -> pipeline_pb2.TaskOutputsSpec.OutputArtifactSpec:
  """Builds the Kubeflow pipeline output artifact spec from TFX channel spec.

  A fresh instance of the channel's type provides the schema. Its MLMD
  properties and custom properties are converted with
  `convert_from_tfx_properties` and copied into the spec.
  """
  artifact_instance = channel_spec.type()
  spec = pipeline_pb2.TaskOutputsSpec.OutputArtifactSpec()
  spec.artifact_type.CopyFrom(
      pipeline_pb2.ArtifactTypeSchema(
          instance_schema=get_artifact_schema(artifact_instance)))

  mlmd_artifact = artifact_instance.mlmd_artifact
  # Regular properties first, then custom properties.
  property_pairs = (
      (mlmd_artifact.properties, spec.properties),
      (mlmd_artifact.custom_properties, spec.custom_properties),
  )
  for tfx_props, target_map in property_pairs:
    for key, value in convert_from_tfx_properties(tfx_props).items():
      target_map[key].CopyFrom(value)
  return spec
def testDoubleNegation(self):
  """Treat `not(not(a))` as `a`.

  Two nested `logical_not` calls around a LESS_THAN predicate must encode to
  the bare comparison, with no negation operator left.
  """
  channel_1 = Channel(type=_MyType)
  channel_2 = Channel(type=_MyType)
  pred = channel_1.future().value < channel_2.future().value
  not_not_pred = ph.logical_not(ph.logical_not(pred))
  channel_to_key_map = {
      channel_1: 'channel_1_key',
      channel_2: 'channel_2_key',
  }
  actual_pb = not_not_pred.encode_with_keys(
      lambda ch: channel_to_key_map[ch])
  expected_pb = text_format.Parse(
      """ operator { compare_op { lhs { operator { artifact_value_op { expression { operator { index_op { expression { placeholder { key: "channel_1_key" } } } } } } } } rhs { operator { artifact_value_op { expression { operator { index_op { expression { placeholder { key: "channel_2_key" } } } } } } } } op: LESS_THAN } } """,
      placeholder_pb2.PlaceholderExpression())
  # assertProtoEquals takes (expected, actual).
  self.assertProtoEquals(expected_pb, actual_pb)
def testPredicateDependentChannels(self):
  """dependent_channels() lists every channel a predicate references.

  Checked for a comparison with a constant, a comparison of two channels,
  and `logical_not` / `logical_and` compositions of those.
  """
  int_a = Channel(type=standard_artifacts.Integer)
  int_b = Channel(type=standard_artifacts.Integer)
  eq_const = int_a.future().value == 1
  eq_channels = int_a.future().value == int_b.future().value
  negated = ph.logical_not(eq_const)
  conjunction = ph.logical_and(eq_const, eq_channels)

  cases = [
      (eq_const, {int_a}),
      (eq_channels, {int_a, int_b}),
      (negated, {int_a}),
      (conjunction, {int_a, int_b}),
  ]
  for predicate, expected_channels in cases:
    self.assertEqual(set(predicate.dependent_channels()), expected_channels)
def testEncodeWithKeys(self):
  """A GREATER_THAN predicate encodes using the keys from the callback."""
  channel_1 = Channel(type=_MyType)
  channel_2 = Channel(type=_MyType)
  pred = channel_1.future().value > channel_2.future().value
  channel_to_key_map = {
      channel_1: 'channel_1_key',
      channel_2: 'channel_2_key',
  }
  actual_pb = pred.encode_with_keys(lambda ch: channel_to_key_map[ch])
  expected_pb = text_format.Parse(
      """ operator { compare_op { lhs { operator { artifact_value_op { expression { operator { index_op { expression { placeholder { key: "channel_1_key" } } } } } } } } rhs { operator { artifact_value_op { expression { operator { index_op { expression { placeholder { key: "channel_2_key" } } } } } } } } op: GREATER_THAN } } """,
      placeholder_pb2.PlaceholderExpression())
  # assertProtoEquals takes (expected, actual).
  self.assertProtoEquals(expected_pb, actual_pb)
def as_channel(artifacts: Iterable[Artifact]) -> Channel:
  """Converts artifact collection of the same artifact type into a Channel.

  Args:
    artifacts: An iterable of Artifact.

  Returns:
    A static Channel typed after the first artifact and populated via
    `set_artifacts` with all of `artifacts`.

  Raises:
    ValueError: when `artifacts` is empty or its first element is not an
      Artifact.
  """
  import collections.abc  # pylint: disable=g-import-not-at-top

  # Peeking at the first element would consume it from a one-shot iterator
  # (e.g. a generator) and drop it from the Channel, so materialize it first.
  if isinstance(artifacts, collections.abc.Iterator):
    artifacts = list(artifacts)
  sentinel = object()
  first_element = next(iter(artifacts), sentinel)
  if first_element is sentinel:
    raise ValueError('Cannot convert empty artifact iterable into Channel')
  if not isinstance(first_element, Artifact):
    raise ValueError('Invalid artifact iterable: {}'.format(artifacts))
  return Channel(type=first_element.type).set_artifacts(artifacts)
def testEncode(self):
  """Without keys, a GREATER_THAN predicate encodes with empty placeholders."""
  channel_1 = Channel(type=_MyType)
  channel_2 = Channel(type=_MyType)
  pred = channel_1.future().value > channel_2.future().value
  actual_pb = pred.encode()
  expected_pb = text_format.Parse(
      """ operator { compare_op { lhs { operator { artifact_value_op { expression { operator { index_op { expression { placeholder {} } } } } } } } rhs { operator { artifact_value_op { expression { operator { index_op { expression { placeholder {} } } } } } } } op: GREATER_THAN } } """,
      placeholder_pb2.PlaceholderExpression())
  # assertProtoEquals takes (expected, actual).
  self.assertProtoEquals(expected_pb, actual_pb)
def type_check(self, arg_name: Text, value: Channel):
  """Checks that `value` is a Channel compatible with `self.type_name`.

  Args:
    arg_name: name of the argument being checked; used in the error message.
    value: the value supplied for that argument.

  Raises:
    TypeError: if `value` is not a Channel. For a Channel, the check is
      delegated to `value.type_check(self.type_name)`.
  """
  if isinstance(value, Channel):
    value.type_check(self.type_name)
    return
  raise TypeError(
      'Argument {!s} should be a Channel of type_name {!r} (got {!s}).'.format(
          arg_name, self.type_name, value))
def testInvalidChannelType(self):
  """set_artifacts raises ValueError for artifacts of a different type."""
  instance_a = _MyType()
  instance_b = _MyType()
  # Build the channel outside `assertRaises` so that only set_artifacts can
  # satisfy the expected ValueError.
  another_type_channel = Channel(_AnotherType)
  with self.assertRaises(ValueError):
    another_type_channel.set_artifacts([instance_a, instance_b])
def testStringTypeNameNotAllowed(self):
  """Constructing a Channel from a plain string type name raises."""
  self.assertRaises(ValueError, Channel, 'StringTypeName')