Ejemplo n.º 1
0
  def testZeroItemsInClusterSpecMasterRead(self):
    """Checks that master() is '' when TF_CONFIG is an empty JSON object."""
    os.environ['TF_CONFIG'] = """
    {}
    """

    resolver = TFConfigClusterResolver()
    master_address = resolver.master()
    self.assertEqual('', master_address)
Ejemplo n.º 2
0
  def testOneItemInClusterSpecMasterRead(self):
    """Checks master() for a one-worker cluster with no "task" entry.

    The TF_CONFIG below lists a single worker and omits the "task" section;
    the test expects master() to return the empty string in that case.
    """
    os.environ['TF_CONFIG'] = """
    {
      "cluster": {
        "worker": ["worker0:2222"]
      }
    }
    """

    resolver = TFConfigClusterResolver()
    master_address = resolver.master()
    self.assertEqual('', master_address)
Ejemplo n.º 3
0
  def testAutomaticMasterRead(self):
    """Checks that master() with no arguments follows TF_CONFIG's task entry.

    The task is declared as ps/0, so the expected master is the first "ps"
    address of the cluster.
    """
    os.environ['TF_CONFIG'] = """
    {
      "cluster": {
        "ps": ["ps0:2222", "ps1:2222"],
        "worker": ["worker0:2222", "worker1:2222", "worker2:2222"]
      },
      "task": {
        "type": "ps",
        "index": 0
      }
    }
    """

    resolver = TFConfigClusterResolver()
    master_address = resolver.master()
    self.assertEqual('ps0:2222', master_address)
Ejemplo n.º 4
0
  def testSpecifiedTaskTypeAndIndexMasterRead(self):
    """Checks that explicit master() arguments win over TF_CONFIG's task.

    TF_CONFIG declares the task as ps/0, but master('worker', 1) is expected
    to return the address of the second worker.
    """
    os.environ['TF_CONFIG'] = """
    {
      "cluster": {
        "ps": ["ps0:2222", "ps1:2222"],
        "worker": ["worker0:2222", "worker1:2222", "worker2:2222"]
      },
      "task": {
        "type": "ps",
        "index": 0
      }
    }
    """

    resolver = TFConfigClusterResolver()
    master_address = resolver.master('worker', 1)
    self.assertEqual('worker1:2222', master_address)
Ejemplo n.º 5
0
  def testParameterOverrides(self):
    """Checks constructor overrides and later property reassignment.

    TF_CONFIG declares the task as ps/1 with rpc_layer "grpc". The resolver
    is built with task_type='ps' and task_id=0, and is then re-pointed at
    worker/1 with rpc_layer 'test' through its properties.
    """
    os.environ['TF_CONFIG'] = """
    {
      "cluster": {
        "ps": ["ps0:2222", "ps1:2222"],
        "worker": ["worker0:2222", "worker1:2222", "worker2:2222"]
      },
      "rpc_layer": "grpc",
      "task": {
        "type": "ps",
        "index": 1
      }
    }
    """

    resolver = TFConfigClusterResolver(task_type='ps', task_id=0)

    # Constructor arguments are expected to take precedence over TF_CONFIG's
    # task index, while rpc_layer still comes from TF_CONFIG.
    self.assertEqual('grpc://ps0:2222', resolver.master())
    self.assertEqual('ps', resolver.task_type)
    self.assertEqual(0, resolver.task_id)

    # Reassign every overridable property on the live resolver.
    resolver.task_type = 'worker'
    resolver.task_id = 1
    resolver.rpc_layer = 'test'

    # All subsequent reads are expected to reflect the reassigned values.
    self.assertEqual('test://worker1:2222', resolver.master())
    self.assertEqual('worker', resolver.task_type)
    self.assertEqual(1, resolver.task_id)
    self.assertEqual('test', resolver.rpc_layer)
Ejemplo n.º 6
0
    def testNumAcceleratorsFilterTasksByEnvVar(self, mock_list_devices,
                                               mock_eager_list_devices):
        """Checks that num_accelerators() counts devices of one task only.

        Args:
          mock_list_devices: mock whose return_value is set to the fake
            session device attributes built below.
          mock_eager_list_devices: mock whose return_value is set to the fake
            logical devices built below.

        NOTE(review): the patch decorators that supply these mocks are not
        visible in this excerpt - confirm their targets against the full file.
        """
        os.environ['TF_CONFIG'] = """
    {
      "cluster": {
        "worker1": ["w10:2222"],
        "worker2": ["w21:2222", "w22:2222", "w23:2222", "w24:2222"]
      },
      "rpc_layer": "grpc",
      "task": {
        "type": "worker1",
        "index": "0"
      }
    }
    """

        # Four devices on worker1/task 0 and one device on each of four
        # worker2 tasks.
        logical_devices = [
            LogicalDevice('/job:worker1/task:0/device:TPU:0', 'TPU'),
            LogicalDevice('/job:worker1/task:0/device:TPU:1', 'TPU'),
            LogicalDevice('/job:worker1/task:0/device:GPU:0', 'GPU'),
            LogicalDevice('/job:worker1/task:0/device:GPU:1', 'GPU'),
            LogicalDevice('/job:worker2/task:1/device:TPU:2', 'TPU'),
            LogicalDevice('/job:worker2/task:2/device:TPU:3', 'TPU'),
            LogicalDevice('/job:worker2/task:3/device:GPU:2', 'GPU'),
            LogicalDevice('/job:worker2/task:4/device:GPU:3', 'GPU'),
        ]

        # Mirror the same devices as session-level device attributes.
        device_attributes = []
        for device in logical_devices:
            device_attributes.append(
                session._DeviceAttributes(device.name, device.device_type,
                                          1024, 0))

        mock_eager_list_devices.return_value = logical_devices
        mock_list_devices.return_value = device_attributes

        resolver = TFConfigClusterResolver()

        # With no arguments the task (worker1/0) is taken from TF_CONFIG.
        default_counts = resolver.num_accelerators()
        self.assertEqual(default_counts, {'TPU': 2, 'GPU': 2})

        # Explicit task_type/task_id arguments still select another task.
        override_counts = resolver.num_accelerators(
            task_type='worker2', task_id=3)
        self.assertEqual(override_counts, {'GPU': 1})
Ejemplo n.º 7
0
  def testSparseClusterSpecRead(self):
    """Checks cluster_spec() when a job is given as an index->address map.

    The "worker" job is declared as {"1": "worker1:2222"} rather than a list,
    and the expected proto contains only task key 1 for that job.
    """
    expected_proto = """
    job { name: 'ps' tasks { key: 0 value: 'ps0:2222' }
                     tasks { key: 1 value: 'ps1:2222' } }
    job { name: 'worker' tasks { key: 1 value: 'worker1:2222' } }
    """
    os.environ['TF_CONFIG'] = """
    {
      "cluster": {
        "ps": ["ps0:2222", "ps1:2222"],
        "worker": {"1": "worker1:2222"}
      },
      "task": {
        "type": "worker",
        "index": 1
      }
    }
    """

    resolver = TFConfigClusterResolver()
    self._verifyClusterSpecEquality(resolver.cluster_spec(), expected_proto)
Ejemplo n.º 8
0
 def testTaskIndexOverride(self):
   """Checks that a task_id constructor argument beats TF_CONFIG's index."""
   os.environ['TF_CONFIG'] = """
   {
     "cluster": {
       "worker": ["worker0:2222", "worker1:2222"]
     },
     "task": {
       "type": "worker",
       "index": "0"
     }
   }
   """
   resolver = TFConfigClusterResolver(task_id=1)
   resolved_task_id = resolver.task_id
   self.assertEqual(1, resolved_task_id)
Ejemplo n.º 9
0
 def testTaskIndexCastToInteger(self):
   """Checks that a string "index" in TF_CONFIG is exposed as the int 1."""
   os.environ['TF_CONFIG'] = """
   {
     "cluster": {
       "ps": ["ps0:2222", "ps1:2222"],
       "worker": ["worker0:2222", "worker1:2222", "worker2:2222"]
     },
     "rpc_layer": "grpc",
     "task": {
       "type": "ps",
       "index": "1"
     }
   }
   """
   resolver = TFConfigClusterResolver()
   resolved_task_id = resolver.task_id
   self.assertEqual(1, resolved_task_id)
Ejemplo n.º 10
0
 def testTaskTypeCastToString(self):
   """Checks that a numeric "type" in TF_CONFIG is exposed as a string."""
   os.environ['TF_CONFIG'] = """
   {
     "cluster": {
       "123456": ["ps0:2222", "ps1:2222"],
       "worker": ["worker0:2222", "worker1:2222", "worker2:2222"]
     },
     "rpc_layer": "grpc",
     "task": {
       "type": 123456,
       "index": 0
     }
   }
   """
   resolver = TFConfigClusterResolver()
   resolved_task_type = resolver.task_type
   self.assertEqual('123456', resolved_task_type)
Ejemplo n.º 11
0
  def testTaskTypeIndexRpcRead(self):
    """Checks that task_type, task_id and rpc_layer are read from TF_CONFIG."""
    os.environ['TF_CONFIG'] = """
    {
      "cluster": {
        "ps": ["ps0:2222", "ps1:2222"],
        "worker": ["worker0:2222", "worker1:2222", "worker2:2222"]
      },
      "rpc_layer": "grpc",
      "task": {
        "type": "ps",
        "index": 0
      }
    }
    """

    resolver = TFConfigClusterResolver()

    # Each property is expected to mirror the matching TF_CONFIG field.
    self.assertEqual('ps', resolver.task_type)
    self.assertEqual(0, resolver.task_id)
    self.assertEqual('grpc', resolver.rpc_layer)