def _assert_in_default_state(t):
  """Asserts that only the default replica context and strategy are active.

  Args:
    t: Test case object used to issue `assertIs` / `assertFalse` checks.
  """
  # Replica side: the current replica context is the default one and there
  # is no cross-replica context.
  default_replica_context = ds_context._get_default_replica_context()
  t.assertIs(default_replica_context, ds_context.get_replica_context())
  t.assertIs(None, ds_context.get_cross_replica_context())
  t.assertFalse(ds_context.in_cross_replica_context())
  # Strategy side: get_strategy() yields the default strategy while
  # has_strategy() is False.
  default_strategy = ds_context._get_default_strategy()
  t.assertIs(default_strategy, ds_context.get_strategy())
  t.assertFalse(ds_context.has_strategy())
def testScopeMostlyNoOp(self):
  """Entering the default strategy's scope leaves the default state intact.

  Also checks that creating a variable under a non-default strategy's scope,
  while the default strategy's scope is entered, raises a `RuntimeError`
  about mixing strategies.
  """
  # `assertRaisesRegex` replaces the deprecated `assertRaisesRegexp` alias,
  # which was removed from unittest in Python 3.12.
  mixing_error = "Mixing different tf.distribute.Strategy objects"
  _assert_in_default_state(self)

  test_strategy = _TestStrategy2()
  with test_strategy.scope():
    variable_scope.variable(1.0, name="before")

  default_strategy = ds_context._get_default_strategy()
  scope = default_strategy.scope()
  with scope:
    _assert_in_default_state(self)

    with test_strategy.scope():
      with self.assertRaisesRegex(RuntimeError, mixing_error):
        variable_scope.variable(1.0, name="error")

    # Re-enter the same default-scope object while it is already active.
    with scope:
      _assert_in_default_state(self)

      with test_strategy.scope():
        with self.assertRaisesRegex(RuntimeError, mixing_error):
          variable_scope.variable(1.0, name="also_error")

    _assert_in_default_state(self)

  _assert_in_default_state(self)
  # Once outside the default scope, the test strategy works again.
  with test_strategy.scope():
    variable_scope.variable(1.0, name="after")
def _assert_in_default_state(t):
  """Asserts that only the default replica context and strategy are active.

  Args:
    t: Test case object used to issue `assertIs` / `assertFalse` checks.
  """
  # Replica side: the current replica context is the default one and there
  # is no cross-replica context.
  default_replica_context = ds_context._get_default_replica_context()
  t.assertIs(default_replica_context, ds_context.get_replica_context())
  t.assertIs(None, ds_context.get_cross_replica_context())
  t.assertFalse(ds_context.in_cross_replica_context())
  # Strategy side: get_strategy() yields the default strategy while
  # has_strategy() is False.
  default_strategy = ds_context._get_default_strategy()
  t.assertIs(default_strategy, ds_context.get_strategy())
  t.assertFalse(ds_context.has_strategy())
def merge_fn(dist, s):
  """Merge function that checks it runs in the default cross-replica context.

  Closes over the enclosing test's `self`. Presumably invoked through a
  replica context's merge call -- confirm against the enclosing test.

  Args:
    dist: Strategy handed to the merge function; must be the default strategy.
    s: Value appended to the "foo_" prefix; must support `str` concatenation.

  Returns:
    The string "foo_" + s.
  """
  expected_strategy = ds_context._get_default_strategy()
  self.assertIs(expected_strategy, dist)
  # No replica context here; `dist` is the cross-replica context instead.
  replica_ctx = ds_context.get_replica_context()
  self.assertIs(None, replica_ctx)
  cross_replica_ctx = ds_context.get_cross_replica_context()
  self.assertIs(dist, cross_replica_ctx)
  self.assertTrue(ds_context.in_cross_replica_context())
  # get_strategy() reports `dist` even though has_strategy() stays False.
  current_strategy = ds_context.get_strategy()
  self.assertIs(dist, current_strategy)
  self.assertFalse(ds_context.has_strategy())
  prefix = "foo_"
  return prefix + s
def merge_fn(dist, s):
  """Merge function that checks it runs in the default cross-replica context.

  Closes over the enclosing test's `self`. Presumably invoked through a
  replica context's merge call -- confirm against the enclosing test.

  Args:
    dist: Strategy handed to the merge function; must be the default strategy.
    s: Value appended to the "foo_" prefix; must support `str` concatenation.

  Returns:
    The string "foo_" + s.
  """
  expected_strategy = ds_context._get_default_strategy()
  self.assertIs(expected_strategy, dist)
  # No replica context here; `dist` is the cross-replica context instead.
  replica_ctx = ds_context.get_replica_context()
  self.assertIs(None, replica_ctx)
  cross_replica_ctx = ds_context.get_cross_replica_context()
  self.assertIs(dist, cross_replica_ctx)
  self.assertTrue(ds_context.in_cross_replica_context())
  # get_strategy() reports `dist` even though has_strategy() stays False.
  current_strategy = ds_context.get_strategy()
  self.assertIs(dist, current_strategy)
  self.assertFalse(ds_context.has_strategy())
  prefix = "foo_"
  return prefix + s
def _create_strategy(self, num_shards):
  """Builds the strategy a test should run under.

  Args:
    num_shards: Number of variable shards requested.

  Returns:
    A `ParameterServerStrategyV2` over `self.cluster_resolver` that uses a
    `FixedShardsPartitioner(num_shards)` when `num_shards > 1`; otherwise the
    default strategy.
  """
  if num_shards > 1:
    resolver = self.cluster_resolver
    partitioner = sharded_variable.FixedShardsPartitioner(num_shards)
    return parameter_server_strategy_v2.ParameterServerStrategyV2(
        resolver, variable_partitioner=partitioner)
  # At most one shard requested: fall back to the default strategy.
  return ds_context._get_default_strategy()
def testExperimentalRunV2(self):
  """Runs an identity step twice through the default strategy.

  Deliberately exercises the `experimental_run_v2` entry point.
  """

  def step_fn(input_data):
    return input_data

  default_strategy = ds_context._get_default_strategy()
  batched_dataset = dataset_ops.Dataset.range(10).batch(2)
  iterator = default_strategy.extended._make_dataset_iterator(batched_dataset)
  next_val = iterator.get_next()
  # The same `next_val` is passed to both invocations.
  for _ in range(2):
    default_strategy.experimental_run_v2(step_fn, args=(next_val,))
def testDistributedDatasetsFromFunction(self):
  """Distributes a function-built dataset under the default strategy.

  Eagerly, the first distributed batch must equal `[0, 1]`. In graph mode,
  only creation of an initializable iterator is exercised.
  """
  default_strategy = ds_context._get_default_strategy()
  eager = context.executing_eagerly()

  # Both modes build the distributed dataset identically.
  dataset_fn = lambda _: dataset_ops.DatasetV2.range(10).batch(2)
  dist_dataset_from_func = (
      default_strategy.experimental_distribute_datasets_from_function(
          dataset_fn))

  if not eager:
    dataset_ops.make_initializable_iterator(dist_dataset_from_func)
    return
  next_val = next(iter(dist_dataset_from_func))
  self.assertAllEqual([0, 1], self.evaluate(next_val))
def testDistributedDatasets(self):
  """Checks that the default strategy distributes a dataset unchanged.

  Uses `DatasetV2` eagerly and `DatasetV1` in graph mode; the graph path goes
  through an explicitly initialized iterator. In either mode the first
  element must equal `[0, 1]`.
  """
  default_strategy = ds_context._get_default_strategy()
  eager = context.executing_eagerly()
  dataset_cls = dataset_ops.DatasetV2 if eager else dataset_ops.DatasetV1
  dataset_fn = lambda _: dataset_cls.range(10).batch(2)
  dist_dataset = default_strategy.experimental_distribute_dataset(
      dataset_fn(distribute_lib.InputContext()))

  if eager:
    next_val = next(iter(dist_dataset))
  else:
    # Graph mode: the iterator must be initialized before reading from it.
    iterator = dist_dataset.make_initializable_iterator()
    self.evaluate(iterator.initializer)
    next_val = iterator.get_next()
  self.assertAllEqual([0, 1], self.evaluate(next_val))
def testDistributedDatasetsFromFunction(self):
  """Checks `experimental_distribute_datasets_from_function` on the default.

  Eagerly, the first distributed batch must equal `[0, 1]`. In graph mode
  the call must raise a `RuntimeError` stating it is only supported when
  eager execution is enabled.
  """
  default_strategy = ds_context._get_default_strategy()
  dataset_fn = lambda _: dataset_ops.DatasetV2.range(10).batch(2)
  if context.executing_eagerly():
    dist_dataset_from_func = (
        default_strategy.experimental_distribute_datasets_from_function(
            dataset_fn))
    next_val = next(iter(dist_dataset_from_func))
    self.assertAllEqual([0, 1], self.evaluate(next_val))
  else:
    # `assertRaisesRegex` replaces the deprecated `assertRaisesRegexp` alias
    # (removed in Python 3.12). The call's result is never used, so it is
    # not bound to a name.
    with self.assertRaisesRegex(
        RuntimeError, "only supported when eager execution is enabled"):
      default_strategy.experimental_distribute_datasets_from_function(
          dataset_fn)