def testParameterOverrides(self):
        """Constructor args and attribute setters override TF_CONFIG values."""
        # TF_CONFIG is process-global state: remember the previous value and
        # restore it after the test so it cannot leak into later tests.
        previous_tf_config = os.environ.get('TF_CONFIG')

        def _restore_tf_config():
            if previous_tf_config is None:
                os.environ.pop('TF_CONFIG', None)
            else:
                os.environ['TF_CONFIG'] = previous_tf_config

        self.addCleanup(_restore_tf_config)

        os.environ['TF_CONFIG'] = """
    {
      "cluster": {
        "ps": ["ps0:2222", "ps1:2222"],
        "worker": ["worker0:2222", "worker1:2222", "worker2:2222"]
      },
      "rpc_layer": "grpc",
      "task": {
        "type": "ps",
        "index": 1
      }
    }
    """

        # TF_CONFIG names task ps:1, but the explicit constructor arguments
        # (ps:0) are what the resolver reports.
        cluster_resolver = TFConfigClusterResolver(task_type='ps',
                                                   task_index=0,
                                                   num_accelerators=8)

        # The rpc_layer ("grpc") still comes from TF_CONFIG.
        self.assertEqual('grpc://ps0:2222', cluster_resolver.master())
        self.assertEqual('ps', cluster_resolver.task_type)
        self.assertEqual(0, cluster_resolver.task_index)
        self.assertEqual(8, cluster_resolver.num_accelerators())

        # Overriding via attribute setters must be reflected in master().
        cluster_resolver.task_type = 'worker'
        cluster_resolver.task_index = 1
        cluster_resolver.rpc_layer = 'test'

        self.assertEqual('test://worker1:2222', cluster_resolver.master())
        self.assertEqual('worker', cluster_resolver.task_type)
        self.assertEqual(1, cluster_resolver.task_index)
        self.assertEqual('test', cluster_resolver.rpc_layer)
    def testNumAcceleratorsFilterTasksByEnvVar(self, mock_list_devices,
                                               mock_eager_list_devices):
        """num_accelerators() counts only the TF_CONFIG task unless overridden.

        Args:
          mock_list_devices: mock whose return value is the session-level
            device attribute list.
          mock_eager_list_devices: mock whose return value is the list of
            LogicalDevice objects.

        NOTE(review): both mocks are presumably injected by patch decorators
        that are not visible here; confirm they are present on this method.
        """
        # TF_CONFIG is process-global state: remember the previous value and
        # restore it after the test so it cannot leak into later tests.
        previous_tf_config = os.environ.get('TF_CONFIG')

        def _restore_tf_config():
            if previous_tf_config is None:
                os.environ.pop('TF_CONFIG', None)
            else:
                os.environ['TF_CONFIG'] = previous_tf_config

        self.addCleanup(_restore_tf_config)

        # Note: the task "index" is given as a JSON string here, not an int.
        os.environ['TF_CONFIG'] = """
    {
      "cluster": {
        "worker1": ["w10:2222"],
        "worker2": ["w21:2222", "w22:2222", "w23:2222", "w24:2222"]
      },
      "rpc_layer": "grpc",
      "task": {
        "type": "worker1",
        "index": "0"
      }
    }
    """

        # worker1/task:0 owns 2 TPUs and 2 GPUs; each worker2 device lives on
        # a different task (1..4), so worker2/task:3 owns exactly one GPU.
        devices = [
            LogicalDevice('/job:worker1/task:0/device:TPU:0', 'TPU'),
            LogicalDevice('/job:worker1/task:0/device:TPU:1', 'TPU'),
            LogicalDevice('/job:worker1/task:0/device:GPU:0', 'GPU'),
            LogicalDevice('/job:worker1/task:0/device:GPU:1', 'GPU'),
            LogicalDevice('/job:worker2/task:1/device:TPU:2', 'TPU'),
            LogicalDevice('/job:worker2/task:2/device:TPU:3', 'TPU'),
            LogicalDevice('/job:worker2/task:3/device:GPU:2', 'GPU'),
            LogicalDevice('/job:worker2/task:4/device:GPU:3', 'GPU'),
        ]
        # The same devices, expressed as session device attributes for the
        # session-level listing mock.
        device_list = [
            session._DeviceAttributes(d.name, d.device_type, 1024, 0)
            for d in devices
        ]
        mock_eager_list_devices.return_value = devices
        mock_list_devices.return_value = device_list

        resolver = TFConfigClusterResolver()

        # By default we read from TF_CONFIG (worker1, task 0).
        self.assertEqual(resolver.num_accelerators(), {'TPU': 2, 'GPU': 2})

        # Override still works when we want it to
        self.assertEqual(
            resolver.num_accelerators(task_type='worker2', task_id=3),
            {'GPU': 1})