コード例 #1
0
    def testPartitionedGraphUsesDefaultExecutor(self):
        """Verify that a partitioned graph is run by the default executor.

        Skipped when no GPU is available, because the scenario requires a
        partitioned graph (as the skip message states). After a single
        ``fill`` op is run, the ("flr_executor", "default") counter must have
        advanced and the ("flr_executor", "single_threaded") counter must not.
        """
        if not context.num_gpus():
            self.skipTest("GPU-only test (requires partitioned graph).")

        # NOTE(review): TestDelta presumably snapshots the named counter cell
        # at construction, and Get() returns the increase since then. Confirm
        # against test_util.
        default_executor = test_util.TestDelta("flr_executor", "default")
        single_threaded = test_util.TestDelta("flr_executor",
                                              "single_threaded")

        # Run one op. Its result is intentionally discarded; only the counter
        # side effects are checked below.
        array_ops.fill([2], constant_op.constant(7, dtype=dtypes.int64))

        # Use TestCase assertions rather than bare `assert`. Bare asserts are
        # stripped under `python -O` and report no values on failure.
        self.assertGreater(default_executor.Get(), 0,
                           "flr_executor/default")
        self.assertEqual(single_threaded.Get(), 0,
                         "flr_executor/single_threaded")
コード例 #2
0
    def testSimpleGraphExecutesSynchronously(self):
        """Verify that an unpartitioned graph takes the synchronous path.

        Skipped when a GPU is available, because the scenario requires an
        unpartitioned graph (as the skip message states). After a single
        ``fill`` op is run, the following must hold:

        * the single-threaded flr_executor was used, not the default one;
        * the "sync" pflr_runsync counter advanced, not the "async" one;
        * the subgraph was counted as "safe_for_sync".
        """
        if context.num_gpus():
            self.skipTest("CPU-only test (requires unpartitioned graph).")

        # NOTE(review): TestDelta presumably snapshots the named counter cell
        # at construction, and Get() returns the increase since then. Confirm
        # against test_util.
        default_executor = test_util.TestDelta("flr_executor", "default")
        single_threaded = test_util.TestDelta("flr_executor",
                                              "single_threaded")
        run_async = test_util.TestDelta("pflr_runsync", "async")
        run_sync = test_util.TestDelta("pflr_runsync", "sync")
        safe = test_util.TestDelta("subgraph_async_summary", "safe_for_sync")

        # Run one op. Its result is intentionally discarded; only the counter
        # side effects are checked below.
        array_ops.fill([2], constant_op.constant(7, dtype=dtypes.int64))

        # Use TestCase assertions rather than bare `assert`. Bare asserts are
        # stripped under `python -O` and report no values on failure.
        self.assertEqual(default_executor.Get(), 0, "flr_executor/default")
        self.assertGreater(single_threaded.Get(), 0,
                           "flr_executor/single_threaded")
        self.assertEqual(run_async.Get(), 0, "pflr_runsync/async")
        self.assertGreater(run_sync.Get(), 0, "pflr_runsync/sync")
        self.assertGreater(safe.Get(), 0,
                           "subgraph_async_summary/safe_for_sync")
コード例 #3
0
    def testSendRecvPartitionedGraphExecutesSynchronously(self):
        """Verify that a partitioned graph with send/recv runs synchronously.

        Skipped when no GPU is available, because the scenario requires a
        partitioned graph (as the skip message states). After a single
        ``fill`` op is run, the following must hold:

        * the single-threaded flr_executor was used, not the default one;
        * the "sync" pflr_runsync counter advanced, not the "async" one;
        * both the "send_only" and "recv_only" subgraph_async_summary
          counters advanced.
        """
        if not context.num_gpus():
            self.skipTest("GPU-only test (requires partitioned graph).")

        # NOTE(review): TestDelta presumably snapshots the named counter cell
        # at construction, and Get() returns the increase since then. Confirm
        # against test_util.
        default_executor = test_util.TestDelta("flr_executor", "default")
        single_threaded = test_util.TestDelta("flr_executor",
                                              "single_threaded")
        run_async = test_util.TestDelta("pflr_runsync", "async")
        run_sync = test_util.TestDelta("pflr_runsync", "sync")
        send_only = test_util.TestDelta("subgraph_async_summary", "send_only")
        recv_only = test_util.TestDelta("subgraph_async_summary", "recv_only")

        # Run one op. Its result is intentionally discarded; only the counter
        # side effects are checked below.
        array_ops.fill([2], constant_op.constant(7, dtype=dtypes.int64))

        # Use TestCase assertions rather than bare `assert`. Bare asserts are
        # stripped under `python -O` and report no values on failure.
        self.assertEqual(default_executor.Get(), 0, "flr_executor/default")
        self.assertGreater(single_threaded.Get(), 0,
                           "flr_executor/single_threaded")
        self.assertEqual(run_async.Get(), 0, "pflr_runsync/async")
        self.assertGreater(run_sync.Get(), 0, "pflr_runsync/sync")
        self.assertGreater(send_only.Get(), 0,
                           "subgraph_async_summary/send_only")
        self.assertGreater(recv_only.Get(), 0,
                           "subgraph_async_summary/recv_only")