コード例 #1
0
ファイル: compiler_utils_test.py プロジェクト: jeongukjae/tfx
    def testIsResolver(self):
        """Check is_resolver on two resolver node kinds and on a non-resolver."""
        strategy = latest_blessed_model_resolver.LatestBlessedModelResolver

        # A Resolver built from a strategy class must be reported as a resolver.
        strategy_based_node = resolver.Resolver(strategy_class=strategy)
        self.assertTrue(compiler_utils.is_resolver(strategy_based_node))

        # The ResolverNode from legacy_resolver_node must be reported as one too.
        legacy_node = legacy_resolver_node.ResolverNode(resolver_class=strategy)
        self.assertTrue(compiler_utils.is_resolver(legacy_node))

        # An example-gen component must not be reported as a resolver.
        csv_gen = CsvExampleGen(input_base="data_path")
        self.assertFalse(compiler_utils.is_resolver(csv_gen))
コード例 #2
0
    def testIsResolver(self):
        """Check is_resolver on named resolver nodes and on a non-resolver."""
        strategy = latest_blessed_model_resolver.LatestBlessedModelResolver

        # A named Resolver built from a strategy class must be accepted.
        strategy_based_node = resolver.Resolver(
            instance_name="test_resolver_name",
            strategy_class=strategy,
        )
        self.assertTrue(compiler_utils.is_resolver(strategy_based_node))

        # The ResolverNode from legacy_resolver_node must be accepted as well.
        legacy_node = legacy_resolver_node.ResolverNode(
            instance_name="test_resolver_name",
            resolver_class=strategy,
        )
        self.assertTrue(compiler_utils.is_resolver(legacy_node))

        # An example-gen component must not be reported as a resolver.
        data_source = external_input("data_path")
        csv_gen = CsvExampleGen(input=data_source)
        self.assertFalse(compiler_utils.is_resolver(csv_gen))
コード例 #3
0
ファイル: resolver_node_test.py プロジェクト: zzhmtxxhh/tfx
 def testImporterDefinition(self):
   """Check exec properties and input/output channels of a ResolverNode.

   NOTE(review): the test name mentions "Importer" but the body only
   exercises ResolverNode -- confirm whether the name is a leftover.
   """
   resolver_cls = latest_artifacts_resolver.LatestArtifactsResolver
   examples_channel = types.Channel(type=standard_artifacts.Examples)

   node = resolver_node.ResolverNode(
       instance_name='my_resolver',
       resolver_class=resolver_cls,
       resolver_configs={'desired_num_of_artifacts': 5},
       channel_to_resolve=examples_channel,
   )

   # The resolver class and its configs should be recorded as exec properties.
   expected_exec_properties = {
       resolver_node.RESOLVER_CLASS: resolver_cls,
       resolver_node.RESOLVER_CONFIGS: {'desired_num_of_artifacts': 5},
   }
   self.assertDictEqual(node.exec_properties, expected_exec_properties)

   # The input channel is the one passed in; the output channel under the
   # same key carries the same type name.
   channel_key = 'channel_to_resolve'
   resolved_input = node.inputs.get_all()[channel_key]
   self.assertEqual(resolved_input, examples_channel)

   resolved_output = node.outputs.get_all()[channel_key]
   self.assertEqual(resolved_output.type_name, examples_channel.type_name)