def testEagerCustomTrainingUnimplementedError(self):
  """Checks that ParameterServerStrategyV1 rejects custom-training-loop APIs.

  Builds an in-process cluster (3 workers, 2 parameter servers), creates a
  ParameterServerStrategyV1 for worker 1 with no GPUs, and asserts that
  `experimental_distribute_dataset`, `distribute_datasets_from_function`,
  `scope` and `run` each raise NotImplementedError whose message mentions
  ParameterServerStrategy.
  """
  cluster_spec = multi_worker_test_base.create_in_process_cluster(
      num_workers=3, num_ps=2)
  cluster_resolver = SimpleClusterResolver(
      cluster_spec=multi_worker_util.normalize_cluster_spec(cluster_spec),
      task_type='worker',
      task_id=1,
      num_accelerators={'GPU': 0})
  strategy = parameter_server_strategy.ParameterServerStrategyV1(
      cluster_resolver)
  dataset = dataset_ops.DatasetV2.from_tensor_slices([5., 6., 7., 8.])

  def train_step(data):
    return math_ops.square(data)

  # assertRaisesRegex matches with re.search, so a plain substring suffices.
  # Note: a trailing '*' here would quantify only the final 'y' (i.e. match
  # 'ParameterServerStrateg'), not "anything after the name".
  expected_message = 'ParameterServerStrategy'

  self.assertRaisesRegex(NotImplementedError, expected_message,
                         strategy.experimental_distribute_dataset,
                         dataset.batch(2))
  self.assertRaisesRegex(NotImplementedError, expected_message,
                         strategy.distribute_datasets_from_function,
                         lambda _: dataset)
  self.assertRaisesRegex(NotImplementedError, expected_message,
                         strategy.scope)
  self.assertRaisesRegex(NotImplementedError, expected_message,
                         strategy.run, train_step)
def create_test_objects(cluster_spec=None,
                        task_type=None,
                        task_id=None,
                        num_gpus=None,
                        sess_config=None):
  """Creates a (strategy, session target, session config) triple for tests.

  When `cluster_spec`, `task_type` and `task_id` are all supplied, a
  ParameterServerStrategyV1 is built from a SimpleClusterResolver for that
  task and the target is a 'grpc://' address taken from `cluster_spec`.
  Otherwise a CentralStorageStrategy over `num_gpus` GPUs is returned with an
  empty target string.

  Args:
    cluster_spec: cluster description (indexed as a dict below); optional.
    task_type: task type of the current task; optional.
    task_id: index of the current task; optional.
    num_gpus: GPUs per task. Defaults to `context.num_gpus()` when None.
    sess_config: base ConfigProto. Defaults to an empty ConfigProto. It is
      deep-copied before the strategy updates it, so the caller's object is
      not modified.

  Returns:
    A tuple `(distribution, target, sess_config)`.
  """
  base_config = sess_config or config_pb2.ConfigProto()
  gpus = context.num_gpus() if num_gpus is None else num_gpus

  if not (cluster_spec and task_type and task_id is not None):
    # Not enough cluster information: use a single-machine strategy.
    strategy = (central_storage_strategy.CentralStorageStrategy.
                _from_num_gpus(gpus))
    session_target = ''
  else:
    resolver = SimpleClusterResolver(
        cluster_spec=multi_worker_util.normalize_cluster_spec(cluster_spec),
        task_type=task_type,
        task_id=task_id,
        num_accelerators={'GPU': gpus})
    strategy = parameter_server_strategy.ParameterServerStrategyV1(resolver)
    # NOTE(review): the address is always looked up under WORKER, even when
    # task_type is something else (e.g. 'chief') -- confirm this is intended.
    session_target = 'grpc://' + cluster_spec[WORKER][task_id]

  return (strategy, session_target,
          strategy.update_config_proto(copy.deepcopy(base_config)))