class ClusterParametersShouldFailTest(test.TestCase, parameterized.TestCase):
  """Cluster-parameter combinations that the combinations library should refuse."""

  @framework_combinations.generate(  # pylint: disable=redundant-keyword-arg
      framework_combinations.combine(
          ds1=combinations.NamedDistribution(
              "Strategy1",
              lambda: None,
              num_workers=2,
              has_chief=True,
          ),
          ds2=combinations.NamedDistribution(
              "Strategy2",
              lambda: None,
              num_workers=2,
              has_chief=True,
          ),
      ),
      test_combinations=(combinations.ClusterCombination(),),
  )
  def testMultipleDistributionMultiWorker(self, ds1, ds2):
    """Two multi-worker strategies requested in the same test case.

    The combinations library should raise an exception for this combination;
    the test body itself deliberately does nothing.
    """
class ClusterParametersShouldFailTest(test.TestCase, parameterized.TestCase):
  """Test cases that are meant to end in an error (see each test's comments)."""

  @framework_combinations.generate(  # pylint: disable=redundant-keyword-arg
      framework_combinations.combine(
          ds1=combinations.NamedDistribution(
              "Strategy1", lambda: None, has_chief=True, num_workers=2),
          ds2=combinations.NamedDistribution(
              "Strategy2", lambda: None, has_chief=True, num_workers=2),
      ),
      test_combinations=(combinations.ClusterCombination(),))
  def testMultipleDistributionMultiWorker(self, ds1, ds2):
    # combinations library should raise an exception.
    pass

  @combinations.generate(combinations.combine(num_workers=2))
  def testUseWithoutStrategy(self):
    # There's no perfect way to check if the test runs in a subprocess. We
    # approximate by checking the presence of TF_CONFIG, which is normally not
    # set in the main process. os.getenv() returns None when the variable is
    # absent, so comparing against "" could never fail; require a non-empty
    # value instead.
    self.assertTrue(os.getenv("TF_CONFIG"), "TF_CONFIG is unset or empty")
    # Raise unconditionally so that reaching this line proves the body ran.
    raise ValueError("actually run")
class ClusterCombinationTest(test.TestCase, parameterized.TestCase):
  """Checks the `has_chief`/`num_workers` arguments from ClusterCombination."""

  # For these tests we need to use `framework.test_combinations` (imported as
  # `framework_combinations`) because our `generate` eats the cluster
  # parameters instead of passing them to the test method.
  #
  # Each test passes `combinations.ClusterCombination()` via
  # `test_combinations` so that `has_chief` and `num_workers` are taken from
  # the strategy, or defaulted, and injected as test arguments.

  @framework_combinations.generate(
      framework_combinations.combine(distribution=[
          combinations.NamedDistribution(
              "HasClusterParams", lambda: None, has_chief=True, num_workers=2),
      ]),
      test_combinations=(combinations.ClusterCombination(),))
  def testClusterParams(self, distribution, has_chief, num_workers):
    # Values declared on the strategy are forwarded to the test.
    self.assertTrue(has_chief)
    self.assertEqual(num_workers, 2)

  @framework_combinations.generate(
      framework_combinations.combine(distribution=[
          combinations.NamedDistribution("NoClusterParams", lambda: None),
      ]),
      test_combinations=(combinations.ClusterCombination(),))
  def testClusterParamsHasDefault(self, distribution, has_chief, num_workers):
    # A strategy without cluster params yields no chief and a single worker.
    self.assertFalse(has_chief)
    self.assertEqual(num_workers, 1)

  @framework_combinations.generate(
      framework_combinations.combine(v=1),
      test_combinations=(combinations.ClusterCombination(),))
  def testClusterParamsNoStrategy(self, v, has_chief, num_workers):
    # With no strategy at all the same defaults are injected.
    self.assertFalse(has_chief)
    self.assertEqual(num_workers, 1)

  @framework_combinations.generate(
      framework_combinations.combine(distribution=[
          combinations.NamedDistribution(
              "WithClusterParams", lambda: None, has_chief=True, num_workers=2),
          combinations.NamedDistribution("WithoutClusterParams", lambda: None),
      ]),
      test_combinations=(combinations.ClusterCombination(),))
  def testClusterParamsAreOptional(self, distribution):
    # If combinations library doesn't raise an exception, the test is passed.
    pass

  @framework_combinations.generate(
      framework_combinations.combine(
          ds1=combinations.NamedDistribution(
              "Strategy1", lambda: None, has_chief=True, num_workers=0),
          ds2=combinations.NamedDistribution(
              "Strategy2", lambda: None, has_chief=False, num_workers=1),
          ds3=combinations.NamedDistribution(
              "Strategy3", lambda: None, has_chief=True, num_workers=0),
      ),
      test_combinations=(combinations.ClusterCombination(),))
  def testMultipleDistributionSingleWorker(self, ds1, ds2, ds3):
    # If combinations library doesn't raise an exception, the test is passed.
    pass

  @combinations.generate(combinations.combine(num_workers=2))
  def testUseWithoutStrategy(self):
    # There's no perfect way to check if the test runs in a subprocess. We
    # approximate by checking the presence of TF_CONFIG, which is normally not
    # set in the main process. os.getenv() returns None when the variable is
    # absent, so comparing against "" could never fail; require a non-empty
    # value instead.
    self.assertTrue(os.getenv("TF_CONFIG"), "TF_CONFIG is unset or empty")
class ClusterCombinationTest(test.TestCase, parameterized.TestCase):
  """Verifies the cluster arguments that ClusterCombination injects.

  The framework-level `generate` (`framework_combinations`) is used on
  purpose: the local `generate` eats the cluster parameters, so they would
  never reach the test methods. Every test opts in by listing
  `combinations.ClusterCombination()` in `test_combinations`.
  """

  @framework_combinations.generate(
      framework_combinations.combine(
          distribution=[
              combinations.NamedDistribution(
                  "HasClusterParams",
                  lambda: None,
                  num_workers=2,
                  has_chief=True,
              ),
          ],
      ),
      test_combinations=(combinations.ClusterCombination(),),
  )
  def testClusterParams(self, distribution, has_chief, num_workers):
    """Values declared on the strategy arrive unchanged."""
    self.assertTrue(has_chief)
    self.assertEqual(num_workers, 2)

  @framework_combinations.generate(
      framework_combinations.combine(
          distribution=[
              combinations.NamedDistribution("NoClusterParams", lambda: None),
          ],
      ),
      test_combinations=(combinations.ClusterCombination(),),
  )
  def testClusterParamsHasDefault(self, distribution, has_chief, num_workers):
    """A strategy declaring nothing yields no chief and one worker."""
    self.assertFalse(has_chief)
    self.assertEqual(num_workers, 1)

  @framework_combinations.generate(
      framework_combinations.combine(v=1),
      test_combinations=(combinations.ClusterCombination(),),
  )
  def testClusterParamsNoStrategy(self, v, has_chief, num_workers):
    """Without any strategy the same defaults are supplied."""
    self.assertFalse(has_chief)
    self.assertEqual(num_workers, 1)

  @framework_combinations.generate(
      framework_combinations.combine(
          distribution=[
              combinations.NamedDistribution(
                  "WithClusterParams",
                  lambda: None,
                  num_workers=2,
                  has_chief=True,
              ),
              combinations.NamedDistribution(
                  "WithoutClusterParams", lambda: None),
          ],
      ),
      test_combinations=(combinations.ClusterCombination(),),
  )
  def testClusterParamsAreOptional(self, distribution):
    """Passes as long as the combinations library raises nothing.

    The method does not declare `has_chief`/`num_workers`, which shows the
    cluster arguments are optional.
    """

  @framework_combinations.generate(
      framework_combinations.combine(
          ds1=combinations.NamedDistribution(
              "Strategy1",
              lambda: None,
              num_workers=0,
              has_chief=True,
          ),
          ds2=combinations.NamedDistribution(
              "Strategy2",
              lambda: None,
              num_workers=1,
              has_chief=False,
          ),
          ds3=combinations.NamedDistribution(
              "Strategy3",
              lambda: None,
              num_workers=0,
              has_chief=True,
          ),
      ),
      test_combinations=(combinations.ClusterCombination(),),
  )
  def testMultipleDistributionSingleWorker(self, ds1, ds2, ds3):
    """Several single-worker strategies together must not make the library raise."""