Beispiel #1
0
 def testForEach_NestedForEach_NotImplemented(self):
   """Expects NotImplementedError when one ForEach is nested in another."""
   # Create the upstream nodes outside the assertion scope so that an
   # unrelated NotImplementedError from setup cannot satisfy the test.
   a = A()
   b = B()
   with self.assertRaises(NotImplementedError):
     with for_each.ForEach(a.outputs['aa']) as aa:
       with for_each.ForEach(b.outputs['bb']) as bb:
         # The node is built only for its side effect; no binding is needed.
         C(aa=aa, bb=bb)
Beispiel #2
0
  def testForEach_DifferentLoop_HasDifferentContext(self):
    """Checks that nodes from two separate ForEach blocks differ in context id."""
    producer_a = A()
    producer_b = B()

    # First loop over the 'aa' output.
    with for_each.ForEach(producer_a.outputs['aa']) as first_loop_var:
      first_consumer = C(aa=first_loop_var, bb=producer_b.outputs['bb'])
    # Second, separately opened loop over the same output.
    with for_each.ForEach(producer_a.outputs['aa']) as second_loop_var:
      second_consumer = C(aa=second_loop_var, bb=producer_b.outputs['bb'])

    # Compare the id of the last context reported for each consumer node.
    last_context_ids = [
        context_manager.get_contexts(node)[-1].id
        for node in (first_consumer, second_consumer)
    ]
    self.assertNotEqual(last_context_ids[0], last_context_ids[1])
Beispiel #3
0
  def testForEach_LoopVariableNotUsed_Disallowed(self):
    """Expects ValueError when a ForEach body never consumes its loop variable.

    Two cases are covered as subtests: the body ignores the loop variable
    entirely, and the body consumes the original channel instead of the
    loop variable.
    """
    with self.subTest('Not using at all'):
      # Setup stays outside assertRaises so that only the loop itself may
      # raise the expected ValueError.
      a = A()
      with self.assertRaises(ValueError):
        with for_each.ForEach(a.outputs['aa']) as aa:  # pylint: disable=unused-variable
          B()  # aa is not used.

    with self.subTest('Source channel is not a loop variable.'):
      a = A()
      with self.assertRaises(ValueError):
        with for_each.ForEach(a.outputs['aa']) as aa:  # pylint: disable=unused-variable
          B(aa=a.outputs['aa'])  # Should use loop var "aa" directly.
Beispiel #4
0
def create_test_pipeline():
    """Builds the sample 'foreach' pipeline that uses ForEach contexts.

    Nodes are created in this order: CsvExampleGen, StatisticsGen (inside a
    ForEach over the examples), a Resolver using LatestArtifactStrategy,
    SchemaGen, Trainer (inside a second ForEach over the examples) and
    Pusher (inside a ForEach over the trainer's models).

    Returns:
        A pipeline.Pipeline holding the six nodes, with caching enabled and
        SYNC execution mode.
    """
    csv_gen = CsvExampleGen(input_base='/data/mydummy_dataset')

    # Per-example statistics.
    with for_each.ForEach(csv_gen.outputs['examples']) as example:
        stats_gen = StatisticsGen(examples=example)

    stats_resolver = resolver.Resolver(
        statistics=stats_gen.outputs['statistics'],
        strategy_class=latest_artifact_strategy.LatestArtifactStrategy,
    )
    stats_resolver = stats_resolver.with_id('latest_stats_resolver')

    schema = SchemaGen(statistics=stats_resolver.outputs['statistics'])

    # Per-example training against the shared schema.
    with for_each.ForEach(csv_gen.outputs['examples']) as example:
        train_args = trainer_pb2.TrainArgs(num_steps=2000)
        model_trainer = Trainer(
            module_file='/src/train.py',
            examples=example,
            schema=schema.outputs['schema'],
            train_args=train_args,
        )

    # Per-model push to the filesystem destination.
    with for_each.ForEach(model_trainer.outputs['model']) as model:
        filesystem = pusher_pb2.PushDestination.Filesystem(
            base_directory='/models')
        destination = pusher_pb2.PushDestination(filesystem=filesystem)
        model_pusher = Pusher(model=model, push_destination=destination)

    nodes = [
        csv_gen,
        stats_gen,
        stats_resolver,
        schema,
        model_trainer,
        model_pusher,
    ]
    return pipeline.Pipeline(
        pipeline_name='foreach',
        pipeline_root='/tfx/pipelines/foreach',
        components=nodes,
        enable_cache=True,
        execution_mode=pipeline.ExecutionMode.SYNC,
    )
Beispiel #5
0
  def testForEach_As_GivesLoopVariable(self):
    """Inspects the object bound by `with ForEach(...) as` for node 'Apple'."""
    apple = A().with_id('Apple')
    with for_each.ForEach(apple.outputs['aa']) as loop_var:
      pass  # Only the bound loop variable is inspected below.

    self.assertIsInstance(loop_var, types.LoopVarChannel)
    self.assertEqual(loop_var.context_id, 'ForEachContext:1')
    self.assertEqual(loop_var.type, AA)

    # The loop variable exposes the channel it was created from via `wrapped`.
    source_channel = loop_var.wrapped
    self.assertEqual(source_channel.producer_component_id, 'Apple')
    self.assertEqual(source_channel.output_key, 'aa')
Beispiel #6
0
 def testForEach_MultipleNodes_NotImplemented(self):
   """Expects NotImplementedError when a ForEach body declares two nodes."""
   # Create the upstream node outside the assertion scope so that an
   # unrelated NotImplementedError from setup cannot satisfy the test.
   a = A()
   with self.assertRaises(NotImplementedError):
     with for_each.ForEach(a.outputs['aa']) as aa:
       b = B(aa=aa)
       # Second node inside the same loop body; built only for its effect.
       C(bb=b.outputs['bb'])