Exemplo n.º 1
0
    def test1(self):
        @def_function.function
        def f():
            a = array_ops.identity(1., name='a')
            b = a + 1
            c = array_ops.identity(2., name='c')
            d = array_ops.identity(a + c, name='d')
            with ops.control_dependencies([b]):
                e = array_ops.identity(3., name='e')
            f = array_ops.identity(c + e, name='f')
            return d, f

        graph = f.get_concrete_function().graph
        order = test_util.topological_sort_operations(graph.get_operations())
        a = graph.get_operation_by_name('a')
        c = graph.get_operation_by_name('c')
        d = graph.get_operation_by_name('d')
        e = graph.get_operation_by_name('e')
        f = graph.get_operation_by_name('f')
        test_util.assert_sequential_execution(order, [a, d])
        test_util.assert_sequential_execution(order, [e, a, f])
        with self.assertRaises(AssertionError):
            test_util.assert_sequential_execution(order, [a, c])
        with self.assertRaises(AssertionError):
            test_util.assert_sequential_execution(order, [f, a, c])
        with self.assertRaises(AssertionError):
            test_util.assert_sequential_execution(order, [d, e, a, c])
        def replica_fn():
            cross_device_ops_lib.CollectiveAllReduce._limited_nccl = False
            cross_device_utils.CollectiveReplicaLauncher._use_collective_v2 = True
            cross_device_utils.CollectiveReplicaLauncher._use_ordering_token = True
            collective, devices, _ = self.make_collective(
                num_processes, required_gpus)
            options = collective_util.Options(
                implementation=CommunicationImplementation.NCCL)

            v_dense = make_per_replica_value([1.0, 1.0], devices)
            v_sparse = make_per_replica_value([
                IndexedSlicesValue([[4., 6.], [5., 6.]], [1, 3], [5, 2]),
                IndexedSlicesValue([[4., 6.], [5., 6.]], [1, 3], [5, 2]),
            ], devices)

            @def_function.function
            def nested_dense():
                collective.reduce(reduce_util.ReduceOp.SUM, v_dense, v_dense,
                                  options)

            @def_function.function
            def nested_sparse():
                collective.reduce(reduce_util.ReduceOp.SUM, v_sparse, v_sparse,
                                  options)

            # All collectives, function calls, if clause and while loops should be
            # chained by control dependencies, so that the execution order is
            # deterministic.
            @def_function.function
            def f():
                # pylint: disable=pointless-statement
                collective.reduce(reduce_util.ReduceOp.SUM, v_sparse, v_sparse,
                                  options)
                # reducing dense value.
                collective.reduce(reduce_util.ReduceOp.SUM, v_dense, v_dense,
                                  options)
                # reducing sparse value.
                collective.reduce(reduce_util.ReduceOp.SUM, v_sparse, v_sparse,
                                  options)
                # reduce dense value in nested tf.function.
                nested_dense()
                # reduce sparse value in nested tf.function.
                nested_sparse()
                # reduce dense value in tf.cond.
                if array_ops.identity(1.0) > array_ops.identity(2.0):
                    collective.reduce(reduce_util.ReduceOp.SUM, v_dense,
                                      v_dense, options)
                else:
                    v_dense
                # reduce sparse value in tf.cond.
                if array_ops.identity(1.0) > array_ops.identity(2.0):
                    v_sparse
                else:
                    collective.reduce(reduce_util.ReduceOp.SUM, v_sparse,
                                      v_sparse, options)
                # reduce dense value in tf.while_loop.
                i = array_ops.identity(1)
                while i < 3:
                    collective.reduce(reduce_util.ReduceOp.SUM, v_dense,
                                      v_dense, options)
                    i += 1
                # reduce sparse value in tf.while_loop.
                i = array_ops.identity(1)
                while i < 3:
                    collective.reduce(reduce_util.ReduceOp.SUM, v_sparse,
                                      v_sparse, options)
                    i += 1
                # reducing dense and sparse value again.
                collective.reduce(reduce_util.ReduceOp.SUM, v_dense, v_dense,
                                  options)
                collective.reduce(reduce_util.ReduceOp.SUM, v_sparse, v_sparse,
                                  options)
                # pylint: enable=pointless-statement

            graph = f.get_concrete_function().graph
            should_be_ordered = set([
                "CollectiveReduceV2", "CollectiveGatherV2", "If", "While",
                "StatefulPartitionedCall"
            ])
            nodes_by_device = {}
            for op in graph.get_operations():
                if op.type in should_be_ordered:
                    if op.device not in nodes_by_device:
                        nodes_by_device[op.device] = []
                    nodes_by_device[op.device].append(op)
            order = test_util.topological_sort_operations(
                graph.get_operations())
            for device in devices:
                device = device_util.canonicalize(device)
                # Those function ops don't have device annotations, but they contain
                # collectives for both devices so we always include them.
                operations = nodes_by_device[device] + nodes_by_device[""]
                # Verify that we get all types of nodes we want.
                self.assertEqual(set(op.type for op in operations),
                                 should_be_ordered)
                test_util.assert_sequential_execution(order, operations)