Example #1
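This test checks that ParameterServerStrategyV1 raises NotImplementedError from each of the eager custom-training entry points: experimental_distribute_dataset, distribute_datasets_from_function, scope, and run.
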
    # Excerpted from a test class; relies on these TF-internal modules:
    # multi_worker_test_base, multi_worker_util, parameter_server_strategy,
    # dataset_ops, math_ops, and SimpleClusterResolver
    # (tensorflow.python.distribute.cluster_resolver).
    def testEagerCustomTrainingUnimplementedError(self):
        # Start an in-process cluster with 3 workers and 2 parameter servers.
        cluster_spec = multi_worker_test_base.create_in_process_cluster(
            num_workers=3, num_ps=2)
        cluster_resolver = SimpleClusterResolver(
            cluster_spec=multi_worker_util.normalize_cluster_spec(
                cluster_spec),
            task_type='worker',
            task_id=1,
            num_accelerators={'GPU': 0})
        strategy = parameter_server_strategy.ParameterServerStrategyV1(
            cluster_resolver)
        dataset = dataset_ops.DatasetV2.from_tensor_slices([5., 6., 7., 8.])

        def train_step(data):
            return math_ops.square(data)

        # Each eager custom-training entry point should raise
        # NotImplementedError under ParameterServerStrategyV1.
        self.assertRaisesRegex(NotImplementedError, 'ParameterServerStrategy*',
                               strategy.experimental_distribute_dataset,
                               dataset.batch(2))

        self.assertRaisesRegex(NotImplementedError, 'ParameterServerStrategy*',
                               strategy.distribute_datasets_from_function,
                               lambda _: dataset)

        self.assertRaisesRegex(NotImplementedError, 'ParameterServerStrategy*',
                               strategy.scope)

        self.assertRaisesRegex(NotImplementedError, 'ParameterServerStrategy*',
                               strategy.run, train_step)
Example #2
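A test helper that builds a distribution strategy, a session target, and a session config. Given a full cluster description it returns a ParameterServerStrategyV1 pointed at the requested worker; otherwise it falls back to a local CentralStorageStrategy with an empty target.
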
import copy

from tensorflow.core.protobuf import config_pb2
from tensorflow.python.distribute import central_storage_strategy
from tensorflow.python.distribute import multi_worker_util
from tensorflow.python.distribute import parameter_server_strategy
from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
from tensorflow.python.eager import context

# Task-type key used to index the cluster spec; equals
# run_config.TaskType.WORKER in the original test module.
WORKER = 'worker'

def create_test_objects(cluster_spec=None,
                        task_type=None,
                        task_id=None,
                        num_gpus=None,
                        sess_config=None):
    sess_config = sess_config or config_pb2.ConfigProto()
    if num_gpus is None:
        num_gpus = context.num_gpus()
    # A full cluster description selects the multi-worker path; otherwise we
    # fall back to a local CentralStorageStrategy.
    if cluster_spec and task_type and task_id is not None:
        cluster_resolver = SimpleClusterResolver(
            cluster_spec=multi_worker_util.normalize_cluster_spec(
                cluster_spec),
            task_type=task_type,
            task_id=task_id,
            num_accelerators={'GPU': num_gpus})
        distribution = parameter_server_strategy.ParameterServerStrategyV1(
            cluster_resolver)
        target = 'grpc://' + cluster_spec[WORKER][task_id]
    else:
        distribution = (central_storage_strategy.CentralStorageStrategy.
                        _from_num_gpus(num_gpus))
        target = ''

    # Deep-copy so the caller's proto is not mutated, then let the strategy
    # amend the session config.
    sess_config = copy.deepcopy(sess_config)
    sess_config = distribution.update_config_proto(sess_config)

    return distribution, target, sess_config
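
A minimal usage sketch (the addresses and task assignment below are hypothetical): build the strategy, session target, and session config for worker 0 of a two-worker, one-ps cluster.

cluster_spec = {
    'worker': ['localhost:12345', 'localhost:23456'],
    'ps': ['localhost:34567'],
}
# Returns a ParameterServerStrategyV1, this worker's grpc target, and the
# updated ConfigProto.
distribution, target, sess_config = create_test_objects(
    cluster_spec=cluster_spec, task_type=WORKER, task_id=0, num_gpus=0)
# target == 'grpc://localhost:12345'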