def testParameterOverrides(self):
    """Constructor arguments and attribute assignment override TF_CONFIG.

    Seeds TF_CONFIG with a ps/worker cluster whose task spec is ("ps", 1),
    then checks that (a) explicit constructor arguments take precedence over
    the embedded task spec, and (b) assigning task_type/task_index/rpc_layer
    after construction overrides the values again.
    """
    os.environ['TF_CONFIG'] = """ { "cluster": { "ps": ["ps0:2222", "ps1:2222"], "worker": ["worker0:2222", "worker1:2222", "worker2:2222"] }, "rpc_layer": "grpc", "task": { "type": "ps", "index": 1 } } """

    resolver = TFConfigClusterResolver(
        task_type='ps', task_index=0, num_accelerators=8)

    # Constructor kwargs win over the task spec in TF_CONFIG: index 0, not 1.
    self.assertEqual('grpc://ps0:2222', resolver.master())
    self.assertEqual('ps', resolver.task_type)
    self.assertEqual(0, resolver.task_index)
    self.assertEqual(8, resolver.num_accelerators())

    # Attribute assignment after construction overrides once more, and
    # master() reflects the new task and rpc_layer.
    resolver.task_type = 'worker'
    resolver.task_index = 1
    resolver.rpc_layer = 'test'

    self.assertEqual('test://worker1:2222', resolver.master())
    self.assertEqual('worker', resolver.task_type)
    self.assertEqual(1, resolver.task_index)
    self.assertEqual('test', resolver.rpc_layer)
def testNumAcceleratorsFilterTasksByEnvVar(self, mock_list_devices,
                                           mock_eager_list_devices):
    """num_accelerators() only counts devices belonging to the current task.

    The mocked device lists span two jobs (worker1 and worker2) and several
    task indices; the resolver must filter by the task named in TF_CONFIG by
    default, and by the explicit task_type/task_id arguments when given.
    """
    os.environ['TF_CONFIG'] = """ { "cluster": { "worker1": ["w10:2222"], "worker2": ["w21:2222", "w22:2222", "w23:2222", "w24:2222"] }, "rpc_layer": "grpc", "task": { "type": "worker1", "index": "0" } } """

    # (device name, device type) pairs spread across jobs and task indices.
    device_specs = [
        ('/job:worker1/task:0/device:TPU:0', 'TPU'),
        ('/job:worker1/task:0/device:TPU:1', 'TPU'),
        ('/job:worker1/task:0/device:GPU:0', 'GPU'),
        ('/job:worker1/task:0/device:GPU:1', 'GPU'),
        ('/job:worker2/task:1/device:TPU:2', 'TPU'),
        ('/job:worker2/task:2/device:TPU:3', 'TPU'),
        ('/job:worker2/task:3/device:GPU:2', 'GPU'),
        ('/job:worker2/task:4/device:GPU:3', 'GPU'),
    ]
    logical_devices = [
        LogicalDevice(name, dev_type) for name, dev_type in device_specs
    ]
    attr_list = [
        session._DeviceAttributes(d.name, d.device_type, 1024, 0)
        for d in logical_devices
    ]

    mock_eager_list_devices.return_value = logical_devices
    mock_list_devices.return_value = attr_list

    resolver = TFConfigClusterResolver()

    # With no arguments, the task is taken from TF_CONFIG: worker1/task 0,
    # which owns two TPUs and two GPUs.
    self.assertEqual(resolver.num_accelerators(), {'TPU': 2, 'GPU': 2})

    # Explicit task_type/task_id arguments still override TF_CONFIG.
    self.assertEqual(
        resolver.num_accelerators(task_type='worker2', task_id=3),
        {'GPU': 1})