Ejemplo n.º 1
0
            strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
            task_id, attempt = get_attempt(strategy, attempts)

            @tf.function
            def replica_fn():
                """Issue one SUM all-reduce over a pair of 64x64 ones tensors.

                The result of the all-reduce is intentionally discarded; only
                the collective itself matters to the caller.
                """
                replica_ctx = tf.distribute.get_replica_context()
                # A large tensor is used on purpose: a small tensor may hang
                # regardless when the worker recovers.
                ones = tf.ones((64, 64))
                tensors = [ones, ones]
                replica_ctx.all_reduce(tf.distribute.ReduceOp.SUM, tensors)

            strategy.run(replica_fn)
            # worker-1 dies here.
            if attempt == 1 and task_id == 1:
                quick_exit(1)
            strategy.run(replica_fn)

        cluster_spec = multi_worker_test_base.create_cluster_spec(
            num_workers=2)
        attempts = multi_process_runner.manager().dict()
        mpr = multi_process_runner.MultiProcessRunner(worker_fn,
                                                      cluster_spec,
                                                      args=(attempts, ),
                                                      auto_restart=True)
        mpr.start()
        mpr.join(timeout=90)


# Call `combinations.main()` only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    combinations.main()
Ejemplo n.º 2
0
        @def_function.function
        def step_fn(i):
            # Per-step body passed to `distribution.run` in the loop below:
            # folds the dataset element `i` into the enclosing `metric`.
            metric.update_state(i)

        for i in dataset:
            distribution.run(step_fn, args=(i, ))

        # This should be the mean of integers 0-9 which has a sum of 45 and a count
        # of 10 resulting in mean of 4.5.
        self.assertEqual(metric.result().numpy(), 4.5)

    @ds_combinations.generate(
        combinations.combine(
            distribution=strategy_combinations.all_strategies,
            mode=["eager"]))
    def test_update_keras_metric_outside_strategy_scope_cross_replica(
            self, distribution):
        """Check a Mean metric built outside the scope but updated inside it.

        The metric is constructed before entering `distribution.scope()` and
        then updated directly under the scope (not through
        `distribution.run`), which is presumably the "cross replica" case the
        test name refers to.
        """
        mean_metric = metrics.Mean("test_metric", dtype=np.float32)

        with distribution.scope():
            for sample in range(10):
                mean_metric.update_state(sample)

        # Feeding the integers 0-9 gives a sum of 45 over a count of 10, so
        # the reported mean must be 4.5.
        expected_mean = 4.5
        self.assertEqual(mean_metric.result().numpy(), expected_mean)


# Call `ds_combinations.main()` only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    ds_combinations.main()