Exemple #1
0
def _assert_in_default_state(t):
    """Assert through `t` that `ds_context` reports its default state.

    Args:
        t: Object providing `assertIs` and `assertFalse` (e.g. a test case).
    """
    # The active replica context must be the default replica context.
    expected_replica_ctx = ds_context._get_default_replica_context()
    t.assertIs(expected_replica_ctx, ds_context.get_replica_context())

    # There must be no cross-replica context.
    cross_replica_ctx = ds_context.get_cross_replica_context()
    t.assertIs(None, cross_replica_ctx)
    t.assertFalse(ds_context.in_cross_replica_context())

    # get_strategy() returns the default strategy, yet has_strategy() is
    # expected to be False in this state.
    expected_strategy = ds_context._get_default_strategy()
    t.assertIs(expected_strategy, ds_context.get_strategy())
    t.assertFalse(ds_context.has_strategy())
Exemple #2
0
    def testScopeMostlyNoOp(self):
        # Entering the default strategy's scope (including re-entering the same
        # scope object while it is already active) is expected to leave
        # ds_context in its default state. Creating a variable under the test
        # strategy's scope from inside the default scope must raise.
        mixing_error = "Mixing different tf.distribute.Strategy objects"

        _assert_in_default_state(self)

        other_strategy = _TestStrategy2()
        with other_strategy.scope():
            variable_scope.variable(1.0, name="before")

        default_scope = ds_context._get_default_strategy().scope()
        with default_scope:
            _assert_in_default_state(self)

            with other_strategy.scope(), self.assertRaisesRegexp(
                    RuntimeError, mixing_error):
                variable_scope.variable(1.0, name="error")

            # Re-enter the very same scope object while it is already active.
            with default_scope:
                _assert_in_default_state(self)

                with other_strategy.scope(), self.assertRaisesRegexp(
                        RuntimeError, mixing_error):
                    variable_scope.variable(1.0, name="also_error")

            _assert_in_default_state(self)

        # Once every default scope has exited, the other strategy's scope can
        # create variables again.
        _assert_in_default_state(self)
        with other_strategy.scope():
            variable_scope.variable(1.0, name="after")
def _assert_in_default_state(t):
  """Assert through `t` that `ds_context` reports its default state.

  Args:
    t: Object providing `assertIs` and `assertFalse` (e.g. a test case).
  """
  # Replica context: must be the default one.
  default_replica_ctx = ds_context._get_default_replica_context()
  current_replica_ctx = ds_context.get_replica_context()
  t.assertIs(default_replica_ctx, current_replica_ctx)

  # Cross-replica context: must be absent.
  t.assertIs(None, ds_context.get_cross_replica_context())
  t.assertFalse(ds_context.in_cross_replica_context())

  # Strategy: get_strategy() yields the default strategy while has_strategy()
  # is expected to be False.
  default_strategy = ds_context._get_default_strategy()
  current_strategy = ds_context.get_strategy()
  t.assertIs(default_strategy, current_strategy)
  t.assertFalse(ds_context.has_strategy())
Exemple #4
0
 def merge_fn(dist, s):
     """Assert the cross-replica state for `dist` and return "foo_" + s.

     Relies on `self` from the enclosing scope for the assertions.
     """
     # `dist` must be the default strategy.
     default_strategy = ds_context._get_default_strategy()
     self.assertIs(default_strategy, dist)

     # No replica context here; `dist` is the cross-replica context instead.
     replica_ctx = ds_context.get_replica_context()
     self.assertIs(None, replica_ctx)
     cross_replica_ctx = ds_context.get_cross_replica_context()
     self.assertIs(dist, cross_replica_ctx)
     self.assertTrue(ds_context.in_cross_replica_context())

     # get_strategy() returns `dist`, while has_strategy() is expected False.
     current_strategy = ds_context.get_strategy()
     self.assertIs(dist, current_strategy)
     self.assertFalse(ds_context.has_strategy())

     merged = "foo_" + s
     return merged
 def merge_fn(dist, s):
   """Assert the cross-replica state for `dist` and return "foo_" + s.

   Relies on `self` from the enclosing scope for the assertions.
   """
   # `dist` must be the default strategy.
   expected_strategy = ds_context._get_default_strategy()
   self.assertIs(expected_strategy, dist)

   # No replica context here; `dist` is the cross-replica context instead.
   self.assertIs(None, ds_context.get_replica_context())
   active_cross_ctx = ds_context.get_cross_replica_context()
   self.assertIs(dist, active_cross_ctx)
   self.assertTrue(ds_context.in_cross_replica_context())

   # get_strategy() returns `dist`, while has_strategy() is expected False.
   active_strategy = ds_context.get_strategy()
   self.assertIs(dist, active_strategy)
   self.assertFalse(ds_context.has_strategy())

   result = "foo_" + s
   return result
 def _create_strategy(self, num_shards):
     """Return the strategy to use for `num_shards` shards.

     Args:
         num_shards: Number of shards. Values greater than 1 select a
             `ParameterServerStrategyV2` built from `self.cluster_resolver`
             with a `FixedShardsPartitioner(num_shards)`.

     Returns:
         The parameter-server strategy when `num_shards > 1`, otherwise the
         default strategy from `ds_context`.
     """
     if num_shards > 1:
         strategy_cls = parameter_server_strategy_v2.ParameterServerStrategyV2
         resolver = self.cluster_resolver
         partitioner = sharded_variable.FixedShardsPartitioner(num_shards)
         return strategy_cls(resolver, variable_partitioner=partitioner)

     # Not sharded: fall back to the default strategy.
     return ds_context._get_default_strategy()
Exemple #7
0
    def testExperimentalRunV2(self):
        # Calls the default strategy's experimental_run_v2 twice with an
        # identity step, passing the value from one get_next() call on a
        # dataset iterator over range(10) batched by 2.
        strategy = ds_context._get_default_strategy()
        batched_range = dataset_ops.Dataset.range(10).batch(2)
        dataset_iterator = strategy.extended._make_dataset_iterator(
            batched_range)
        step_input = dataset_iterator.get_next()

        def train_step(input_data):
            # Identity step: hands its input straight back.
            return input_data

        run_count = 2
        for _ in range(run_count):
            strategy.experimental_run_v2(train_step, args=(step_input,))
Exemple #8
0
 def testDistributedDatasetsFromFunction(self):
     # Distributes a dataset function through the default strategy. In eager
     # mode the first batch is checked to be [0, 1]; otherwise the distributed
     # dataset is only passed to make_initializable_iterator.
     strategy = ds_context._get_default_strategy()
     make_dataset = lambda _: dataset_ops.DatasetV2.range(10).batch(2)

     if not context.executing_eagerly():
         graph_dataset = (
             strategy.experimental_distribute_datasets_from_function(
                 make_dataset))
         dataset_ops.make_initializable_iterator(graph_dataset)
         return

     eager_dataset = strategy.experimental_distribute_datasets_from_function(
         make_dataset)
     first_batch = next(iter(eager_dataset))
     self.assertAllEqual([0, 1], self.evaluate(first_batch))
Exemple #9
0
 def testDistributedDatasets(self):
     # Distributes range(10).batch(2) through the default strategy and checks
     # that the first value obtained is [0, 1]. Eager mode builds the dataset
     # from DatasetV2 and uses Python iteration; otherwise DatasetV1 and an
     # initializable iterator are used.
     strategy = ds_context._get_default_strategy()
     eager = context.executing_eagerly()

     dataset_cls = dataset_ops.DatasetV2 if eager else dataset_ops.DatasetV1
     make_dataset = lambda _: dataset_cls.range(10).batch(2)
     distributed = strategy.experimental_distribute_dataset(
         make_dataset(distribute_lib.InputContext()))

     if eager:
         first_batch = next(iter(distributed))
     else:
         graph_iterator = distributed.make_initializable_iterator()
         self.evaluate(graph_iterator.initializer)
         first_batch = graph_iterator.get_next()

     self.assertAllEqual([0, 1], self.evaluate(first_batch))
 def testDistributedDatasetsFromFunction(self):
   # Distributes a dataset function through the default strategy. In eager
   # mode the first batch is checked to be [0, 1]; otherwise the call itself
   # must raise a RuntimeError matching the expected message.
   strategy = ds_context._get_default_strategy()
   make_dataset = lambda _: dataset_ops.DatasetV2.range(10).batch(2)

   if not context.executing_eagerly():
     expected_error = "only supported when eager execution is enabled"
     with self.assertRaisesRegexp(RuntimeError, expected_error):
       strategy.experimental_distribute_datasets_from_function(make_dataset)
     return

   eager_dataset = strategy.experimental_distribute_datasets_from_function(
       make_dataset)
   first_batch = next(iter(eager_dataset))
   self.assertAllEqual([0, 1], self.evaluate(first_batch))