  def testSimpleRetrievalFromEnv(self):
    slurm_cluster_resolver = SlurmClusterResolver()

    actual_cluster_spec = slurm_cluster_resolver.cluster_spec()
    expected_proto = """
    job { name: 'worker' tasks { key: 0 value: 't02n13:8888' }
                         tasks { key: 1 value: 't02n41:8888' }
                         tasks { key: 2 value: 't02n43:8888' } }
    """
    self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto)
    self.assertEqual(
        slurm_cluster_resolver.master('worker', 0, rpc_layer='grpc'),
        'grpc://t02n13:8888')
    self.assertEqual(slurm_cluster_resolver.num_accelerators(), {'GPU': 1})
    self.assertEqual(os.environ['CUDA_VISIBLE_DEVICES'], '0')
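
  # The tests in this class rely on cluster information coming from the Slurm
  # job environment. Below is a minimal sketch of how such an environment
  # could be injected for the three-node cluster used above. The helper name
  # and the exact SLURM_* variables and values are illustrative assumptions,
  # not taken from this file.
  def _patched_slurm_environment(self):
    from unittest import mock  # local import keeps the sketch self-contained

    # Hypothetical values describing one task per node on three nodes.
    return mock.patch.dict(
        os.environ, {
            'SLURM_PROCID': '0',
            'SLURM_STEP_NUM_TASKS': '3',
            'SLURM_TASKS_PER_NODE': '1(x3)',
            'SLURM_STEP_NODELIST': 't02n[13,41,43]',
        })
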
  def testTaskPerNodeNotSetRetrieval(self):
    slurm_cluster_resolver = SlurmClusterResolver(
        jobs={
            'ps': 1,
            'worker': 2
        },
        port_base=8888,
        gpus_per_node=1,
        gpus_per_task=1,
        auto_set_gpu=False)

    actual_cluster_spec = slurm_cluster_resolver.cluster_spec()
    expected_proto = """
    job { name: 'ps' tasks { value: 't02n13:8888' } }
    job { name: 'worker' tasks { key: 0 value: 't02n41:8888' }
                         tasks { key: 1 value: 't02n43:8888' } }
    """
    self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto)
  def testMultipleGpusPerTaskRetrieval(self):
    slurm_cluster_resolver = SlurmClusterResolver(
        jobs={
            'ps': 1,
            'worker': 4
        },
        port_base=8888,
        gpus_per_node=4,
        gpus_per_task=2,
        auto_set_gpu=True)

    actual_cluster_spec = slurm_cluster_resolver.cluster_spec()
    expected_proto = """
    job { name: 'ps' tasks { value: 't02n13:8888' } }
    job { name: 'worker' tasks { key: 0 value: 't02n13:8889' }
                         tasks { key: 1 value: 't02n41:8888' }
                         tasks { key: 2 value: 't02n41:8889' }
                         tasks { key: 3 value: 't02n43:8888' } }
    """

    self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto)
    # auto_set_gpu=True should restrict this process to its own 2-GPU slice of
    # the node's 4 GPUs; here that is expected to be GPUs 2 and 3.
    self.assertEqual(os.environ['CUDA_VISIBLE_DEVICES'], '2,3')
  def testSimpleMasterRetrieval(self):
    slurm_cluster_resolver = SlurmClusterResolver(
        jobs={
            'ps': 1,
            'worker': 2
        },
        port_base=8888,
        tasks_per_node=1,
        gpus_per_node=1,
        gpus_per_task=1,
        auto_set_gpu=False)

    slurm_cluster_resolver.task_type = 'worker'
    slurm_cluster_resolver.task_id = 1
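    # With no arguments, master() should use the resolver's own task_type and
    # task_id set above (worker task 1) and the default grpc scheme.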
    self.assertEqual(slurm_cluster_resolver.master(), 'grpc://t02n43:8888')

    slurm_cluster_resolver.rpc_layer = 'ab'
    self.assertEqual(slurm_cluster_resolver.master('ps', 0), 'ab://t02n13:8888')
    self.assertEqual(
        slurm_cluster_resolver.master('ps', 0, rpc_layer='test'),
        'test://t02n13:8888')