self.assertEqual(self.evaluate(val.read_value()), 0)
        for i in range(0, len(devices)):
            with distribute_lib.ReplicaContext(distribution, i):
                val = v1._get()
                self.assertIsInstance(val, packed.PackedVarAndDevice)
                self.assertEqual(val.device, devices[i])
                self.assertEqual(self.evaluate(val.read_value()), i)

    def testIgnorePackedVariableInSaveContext(self, distribution):
        """The packed variable handle must be hidden inside a save context."""
        # Force packing in eager mode so the variable actually carries a
        # PackedDistributedVariable outside of saving.
        distribution._enable_packed_variable_in_eager_mode = True
        with distribution.scope():
            v = variables_lib.Variable(0)
            self.assertIsInstance(v._packed_variable,
                                  packed.PackedDistributedVariable)

        # While a save is in progress the packed handle is suppressed (None)
        # so it is not captured into the SavedModel.
        with save_context.save_context(save_options.SaveOptions()):
            self.assertIsNone(v._packed_variable)


def _make_index_slices(values, indices, dense_shape=None):
    """Build an `IndexedSlices` from `values` and `indices`.

    Args:
        values: tensor-like slice values.
        indices: tensor-like indices into the first dimension.
        dense_shape: optional shape of the corresponding dense tensor.

    Returns:
        An `indexed_slices.IndexedSlices` wrapping identity copies of the
        inputs, so each call produces fresh tensors.
    """
    # Explicit None check: `if dense_shape:` would skip an empty (falsy)
    # shape and raises when dense_shape is a symbolic tensor, whose truth
    # value is not defined.
    if dense_shape is not None:
        dense_shape = array_ops.identity(dense_shape)
    return indexed_slices.IndexedSlices(array_ops.identity(values),
                                        array_ops.identity(indices),
                                        dense_shape)


# Standard entry point for this distribute test file.
if __name__ == "__main__":
    ds_test_util.main()
                "context with MultiWorkerMirroredStrategy.")
        with distribution.scope():
            w_assign, w_apply, ema_w = distribution.run(
                self._ema_replica_fn_graph)
        self.assertEqual(ema_w.name, "w/ExponentialMovingAverage:0")
        self.evaluate(variables.global_variables_initializer())
        self.evaluate(distribution.experimental_local_results(w_apply))
        self.evaluate(distribution.experimental_local_results(w_assign))
        self.evaluate(distribution.experimental_local_results(w_apply))
        self.assertAllClose(
            self.evaluate(distribution.experimental_local_results(ema_w))[0],
            [0.89999998])

    @combinations.generate(all_combinations)
    def testCrossReplicaContextGraph(self, distribution):
        """EMA ops built in cross-replica context produce the expected value."""
        with distribution.scope():
            w_assign, w_apply, ema_w = self._ema_replica_fn_graph()
        self.assertEqual(ema_w.name, "w/ExponentialMovingAverage:0")

        local_results = distribution.experimental_local_results
        self.evaluate(variables.global_variables_initializer())
        # Sequence matters: apply, assign, apply — the EMA value after this
        # exact order is asserted below.
        self.evaluate(local_results(w_apply))
        self.evaluate(local_results(w_assign))
        self.evaluate(local_results(w_apply))
        self.assertAllClose(
            self.evaluate(local_results(ema_w))[0], [0.89999998])


# Entry point; logical devices are disabled pending the bug below.
if __name__ == "__main__":
    # TODO(b/172304955): enable logical devices.
    test_util.main(config_logical_devices=False)
# Ejemplo n.º 3 (aggregator separator; commented out — bare text is not Python)
# 0
    @combinations.generate(
        combinations.combine(strategy=[
            strategy_combinations.multi_worker_mirrored_2x1_cpu,
            strategy_combinations.multi_worker_mirrored_2x1_gpu,
            strategy_combinations.multi_worker_mirrored_2x2_gpu,
            strategy_combinations.multi_worker_mirrored_4x1_cpu,
        ]))
    def testMultiWorkerMirrored(self, strategy):
        """Every multi-worker-mirrored fixture yields the collective strategy."""
        expected_cls = collective_all_reduce_strategy.CollectiveAllReduceStrategy
        self.assertIsInstance(strategy, expected_cls)

    @combinations.generate(
        combinations.combine(strategy=[
            strategy_combinations.central_storage_strategy_with_gpu_and_cpu,
            strategy_combinations.central_storage_strategy_with_two_gpus,
        ]))
    def testCentralStorage(self, strategy):
        """Central-storage fixtures resolve to CentralStorageStrategy."""
        expected_cls = central_storage_strategy.CentralStorageStrategy
        self.assertIsInstance(strategy, expected_cls)

    @combinations.generate(
        combinations.combine(strategy=strategy_combinations.tpu_strategies))
    def testTPU(self, strategy):
        """Each TPU fixture resolves to a TPUStrategy instance."""
        self.assertIsInstance(strategy, tpu_strategy.TPUStrategy)


# Standard distribute-aware test entry point.
if __name__ == "__main__":
    test_util.main()
# Ejemplo n.º 4 (aggregator separator; commented out — bare text is not Python)
# 0
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test for tfr mnist training example."""

from absl.testing import parameterized

from tensorflow.compiler.mlir.tfr.examples.mnist import mnist_train
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.distribute import test_util as distribute_test_util
from tensorflow.python.framework import test_util

# Strategy fixtures exercised by MnistTrainTest below: single-device CPU,
# single-device GPU, and TPU.
strategies = [
    strategy_combinations.one_device_strategy,
    strategy_combinations.one_device_strategy_gpu,
    strategy_combinations.tpu_strategy,
]


class MnistTrainTest(test_util.TensorFlowTestCase, parameterized.TestCase):
    """Smoke test: MNIST training reaches a minimum accuracy per strategy."""

    @combinations.generate(combinations.combine(strategy=strategies))
    def testMnistTrain(self, strategy):
        final_accuracy = mnist_train.main(strategy)
        # 0.7 is a loose sanity bound, well below converged MNIST accuracy.
        self.assertGreater(final_accuracy, 0.7, 'accuracy sanity check')


# Distribute-aware test entry point.
if __name__ == '__main__':
    distribute_test_util.main()