Code example #1
Votes: 0
 def testModelSplit(self):
   """Nested SetModelSplit contexts override split_id; the outer cluster keeps 0."""
   params = cluster_factory.Cluster.Params()
   with params.Instantiate() as outer_cluster:
     with cluster_factory.SetModelSplit(2) as split_two:
       self.assertEqual(split_two.params.split_id, 2)
       with cluster_factory.SetModelSplit(3) as split_three:
         self.assertEqual(split_three.params.split_id, 3)
     self.assertEqual(outer_cluster.params.split_id, 0)
Code example #2
Votes: 0
File: cluster_test.py  Project: ruby11dog/lingvo
 def testWorkerDeviceInModelSplitSync(self):
   """In sync mode, split 1's worker-local device 1 resolves to GPU 3 on task 0."""
   params = cluster_factory.Cluster.Params()
   params.mode = 'sync'
   params.job = 'trainer_client'
   params.worker.name = '/job:trainer'
   params.worker.replicas = 4
   params.worker.gpus_per_replica = 4
   params.worker.devices_per_split = 2
   with cluster_factory.Cluster(params):
     with cluster_factory.SetModelSplit(1) as split_cluster:
       actual_device = split_cluster.WorkerDeviceInModelSplit(1)
       expected_device = split_cluster._MakeDeviceString(
           job_name='/job:trainer', task_id=0, device_name='GPU', device_id=3)
   self.assertEqual(expected_device, actual_device)
Code example #3
Votes: 0
File: base_model.py  Project: ai-learn-use/lingvo
  def _FPropSplitInputBatch(self, theta, input_batch):
    """Splits the input batch on the input device.

    Runs FPropTower once per model split, pinning each split's ops to the
    split's first worker device, then aggregates results across splits.

    Args:
      theta: model weights passed to FPropTower; a per-worker local copy is
        made via py_utils.CreateLocalTheta.
      input_batch: one batch per split, as a list of length
        num_splits_per_client; a non-list value is wrapped in a
        single-element list (so it only works when there is exactly one
        split).

    Returns:
      A tuple of the weighted average of per-split metrics and the
      concatenation of per-split per-example tensors.
    """
    # TPU uses a different splitting strategy; delegate entirely.
    if py_utils.use_tpu():
      return self._FPropTpu(theta, input_batch)

    cluster = self.cluster
    num_splits = cluster.num_splits_per_client

    if not isinstance(input_batch, list):
      input_batch = [input_batch]

    # Exactly one batch per split is required.
    assert len(input_batch) == num_splits, (len(input_batch), num_splits)

    # dev_list_per_replica[i][j] is the i-th worker's j-th device.
    dev_list_per_replica = cluster.available_devices.tolist()

    # Asserts invariant of the total number of splits w.r.t.,
    # splits per worker.
    splits_per_replica = cluster.num_splits_per_replica
    assert num_splits == splits_per_replica * len(dev_list_per_replica), (
        num_splits, splits_per_replica, len(dev_list_per_replica))

    all_metrics = []
    all_per_example_tensors = []
    with cluster:
      for w_id, w_devs in enumerate(dev_list_per_replica):
        # Make local copy of the vars, shard on devices for this worker.
        theta_local = py_utils.CreateLocalTheta(
            theta, w_devs, label='worker %d' % w_id)

        for s_id in range(splits_per_replica):
          # s_id-th split for the w_id-th worker.
          split_id = splits_per_replica * w_id + s_id
          # Build this split's tower inside its split context, placed on the
          # split's first device, under a unique name scope.
          with cluster_factory.SetModelSplit(split_id) as c:
            with tf.device(c.WorkerDeviceInModelSplit(0)):
              with tf.name_scope('tower_%d_%d' % (w_id, s_id)):
                batch = input_batch[split_id]
                metrics, per_example = self.FPropTower(theta_local, batch)
          all_metrics.append(metrics)
          all_per_example_tensors.append(per_example)
    return py_utils.WeightedAvgOfMetrics(
        all_metrics), py_utils.ConcatPerExampleTensors(all_per_example_tensors)