def testPartitionToOne(self):
        # Variables smaller than min_shard_bytes (64 MiB here) are kept in a
        # single partition.
        variable_partitioner = sharded_variable.MinSizePartitioner(
            min_shard_bytes=64 << 20, max_shards=2)
        strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
            self.cluster_resolver, variable_partitioner)
        with strategy.scope():
            initializer = init_ops_v2.Constant([0] * 10)
            v1 = variables.Variable(
                initial_value=lambda: initializer(shape=(10,), dtype=dtypes.int64),
                shape=(10,),
                dtype=dtypes.int64)

            v2 = variables.Variable([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

        self.assertIsInstance(v1, variables.Variable)
        self.assertNotIsInstance(v1, sharded_variable.ShardedVariable)
        self.assertRegex(v1.device, "/job:ps/replica:0/task:0")
        self.assertAllEqual(v1, [0] * 10)

        self.assertIsInstance(v2, variables.Variable)
        self.assertNotIsInstance(v2, sharded_variable.ShardedVariable)
        self.assertRegex(v2.device, "/job:ps/replica:0/task:1")
        self.assertAllEqual(v2, [0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
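
# A companion sketch, not from the original tests: with the same
# MinSizePartitioner, a variable of at least 2 * min_shard_bytes (128 MiB
# here) should be split across both parameter servers and surface as a
# ShardedVariable. The shape below is illustrative only.
def testPartitionToTwoSketch(self):
        variable_partitioner = sharded_variable.MinSizePartitioner(
            min_shard_bytes=64 << 20, max_shards=2)
        strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
            self.cluster_resolver, variable_partitioner)
        with strategy.scope():
            # 16M int64 elements = 128 MiB, exactly two 64 MiB shards.
            initializer = init_ops_v2.Zeros()
            v = variables.Variable(
                initial_value=lambda: initializer(shape=(16 << 20,),
                                                  dtype=dtypes.int64),
                shape=(16 << 20,),
                dtype=dtypes.int64)

        self.assertIsInstance(v, sharded_variable.ShardedVariable)
        self.assertLen(v.variables, 2)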
Example #2
    def skip_if_cannot_start_grpc_server():
        # Treat failure to bind the std gRPC server ports under bazel as an
        # environment problem and skip the test; re-raise any other
        # UnknownError.
        try:
            return _create_multi_worker_mirrored()
        except errors.UnknownError as e:
            if "Could not start gRPC server" in e.message and (
                    len(sys.argv) >= 1 and "bazel" in sys.argv[0]):
                raise unittest.SkipTest("Cannot start std servers.")
            else:
                raise

    return skip_if_cannot_start_grpc_server
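
# Hedged usage note: this excerpt begins mid-function, so the enclosing
# factory name below is an assumption. Callers invoke the returned closure
# lazily, which is why the SkipTest surfaces only when the strategy is
# actually built:
#
#   creator = _get_multi_worker_mirrored_creator(required_gpus=0)
#   strategy = creator()  # skips if the std gRPC servers cannot start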


# Due to b/195615322, FixedShardsPartitioner will wrongly partition the
# RNG state, so we use MinSizePartitioner as the default. The maximum RNG
# state size is int64[3], i.e. 8 * 3 = 24 bytes, so we set min_shard_bytes
# to 8 * 3 + 1 = 25: one byte larger than the state, so it is never split.
DEFAULT_PARTITIONER = sharded_variable.MinSizePartitioner(
    min_shard_bytes=8 * 3 + 1, max_shards=2)
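
# A hedged worked example of the comment above (not in the original source;
# the local imports are assumptions about what is available here). The RNG
# state, int64[3], is 3 * 8 = 24 bytes, one byte below min_shard_bytes, so
# MinSizePartitioner keeps it in a single shard while FixedShardsPartitioner
# would split it (b/195615322).
def _rng_state_partitioning_sketch():
    from tensorflow.python.framework import dtypes, tensor_shape

    rng_state_shape = tensor_shape.TensorShape([3])
    # 24 // 25 == 0, clamped to at least 1 -> one partition on axis 0.
    assert DEFAULT_PARTITIONER(rng_state_shape, dtypes.int64) == [1]
    # FixedShardsPartitioner ignores size and always requests two shards.
    assert sharded_variable.FixedShardsPartitioner(2)(
        rng_state_shape, dtypes.int64) == [2]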


def _get_ps_strategy_creator(num_workers,
                             num_ps,
                             required_gpus=0,
                             variable_partitioner=DEFAULT_PARTITIONER):
    def _create_ps_strategy(resolver, variable_partitioner):
        return parameter_server_strategy_v2.ParameterServerStrategyV2(
            resolver, variable_partitioner=variable_partitioner)

    def _create_parameter_server():
        if framework_test_util.is_xla_enabled():
            # To work around test failures when XLA is combined with
            # MultiProcessRunner, continue to use an in-process cluster for
            # XLA tests.
            cluster_def = multi_worker_test_base.create_in_process_cluster(