def testSimpleMasterRetrieval(self):
    """Checks `master()` for the resolver's own task and for explicit tasks.

    Three lookups are verified in turn: the resolver's own task (worker 1)
    with no RPC layer set on the resolver, an explicit task after setting the
    resolver-level `rpc_layer`, and an explicit task with a per-call
    `rpc_layer` argument.

    NOTE(review): the expected host names are not defined in this method;
    presumably they come from Slurm environment variables patched in outside
    this view -- confirm against the enclosing test class.
    """
    job_counts = {'ps': 1, 'worker': 2}
    resolver = SlurmClusterResolver(
        jobs=job_counts,
        port_base=8888,
        tasks_per_node=1,
        gpus_per_node=1,
        gpus_per_task=1,
        auto_set_gpu=False)

    # With no arguments, master() is expected to describe worker 1.
    resolver.task_type = 'worker'
    resolver.task_id = 1
    own_master = resolver.master()
    self.assertEqual(own_master, 'grpc://t02n43:8888')

    # A resolver-level rpc_layer is expected to become the URL prefix.
    resolver.rpc_layer = 'ab'
    ps_master = resolver.master('ps', 0)
    self.assertEqual(ps_master, 'ab://t02n13:8888')

    # A per-call rpc_layer is expected to win over the resolver-level one.
    ps_master_override = resolver.master('ps', 0, rpc_layer='test')
    self.assertEqual(ps_master_override, 'test://t02n13:8888')
    def testSimpleMasterRetrieval(self):
        """Verifies the address strings returned by `master()`.

        Exercises the no-argument form (using the task type and id assigned
        on the resolver), then an explicit ('ps', 0) lookup with a
        resolver-level `rpc_layer`, then the same lookup with a per-call
        `rpc_layer` override.

        NOTE(review): host names such as t02n13/t02n43 are not set up here;
        presumably the Slurm environment is patched outside this view --
        confirm.
        """
        resolver_kwargs = dict(
            jobs={'ps': 1, 'worker': 2},
            port_base=8888,
            tasks_per_node=1,
            gpus_per_node=1,
            gpus_per_task=1,
            auto_set_gpu=False,
        )
        cluster_resolver = SlurmClusterResolver(**resolver_kwargs)

        # Identify this process as worker 1 before asking for its master.
        cluster_resolver.task_type = 'worker'
        cluster_resolver.task_id = 1
        default_address = cluster_resolver.master()
        self.assertEqual(default_address, 'grpc://t02n43:8888')

        # Switch the resolver-wide RPC layer and look up the first ps task.
        cluster_resolver.rpc_layer = 'ab'
        ps_address = cluster_resolver.master('ps', 0)
        self.assertEqual(ps_address, 'ab://t02n13:8888')

        # The per-call rpc_layer should be used instead of 'ab'.
        overridden_address = cluster_resolver.master(
            'ps', 0, rpc_layer='test')
        self.assertEqual(overridden_address, 'test://t02n13:8888')
    def testSimpleRetrievalFromEnv(self):
        """Verifies a resolver built with no arguments.

        Checks the cluster spec, the master address of worker 0, the reported
        accelerator count, and the resulting `CUDA_VISIBLE_DEVICES` value.

        NOTE(review): the environment this test reads is not set up in this
        method; presumably it is patched by a decorator or fixture outside
        this view -- confirm.
        """
        resolver = SlurmClusterResolver()

        # The whitespace inside this literal is kept exactly as authored.
        expected_spec_text = """
    job { name: 'worker' tasks { key: 0 value: 't02n13:8888' }
                         tasks { key: 1 value: 't02n41:8888' }
                         tasks { key: 2 value: 't02n43:8888' } }
    """
        cluster_spec = resolver.cluster_spec()
        self._verifyClusterSpecEquality(cluster_spec, expected_spec_text)

        worker0_master = resolver.master('worker', 0, rpc_layer='grpc')
        self.assertEqual(worker0_master, 'grpc://t02n13:8888')

        accelerator_counts = resolver.num_accelerators()
        self.assertEqual(accelerator_counts, {'GPU': 1})

        # The test expects the visible-device variable to read '0' by now.
        visible_devices = os.environ['CUDA_VISIBLE_DEVICES']
        self.assertEqual(visible_devices, '0')