Example #1
  def testWorkerDeviceInModelSplitSync(self):
    p = cluster_factory.Cluster.Params()
    p.mode = 'sync'
    p.job = 'trainer_client'
    p.worker.name = '/job:trainer'
    p.worker.replicas = 4
    p.worker.gpus_per_replica = 4
    p.worker.devices_per_split = 2
    c = cluster_factory.Cluster(p)
    with py_utils.ModelSplit(1):
      d = c.WorkerDeviceInModelSplit(1)
    expected_device = c._MakeDeviceString(
        job_name='/job:trainer', task_id=0, device_name='GPU', device_id=3)
    self.assertEqual(expected_device, d)
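The expected device id in this test follows from the configured topology: with gpus_per_replica = 4 and devices_per_split = 2, split 1 owns GPUs 2 and 3, so device index 1 within that split is GPU 3. Below is a minimal sketch of that arithmetic; the helper name is illustrative and not part of lingvo.

# Illustrative sketch only: the split-to-device arithmetic implied by the test
# above, assuming each split owns a contiguous block of GPUs on the replica.
def device_id_in_split(split_id, device_index, devices_per_split=2):
  return split_id * devices_per_split + device_index

# Split 1 with 2 devices per split owns GPUs 2 and 3; index 1 -> GPU 3.
assert device_id_in_split(split_id=1, device_index=1) == 3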
Example #2
  def _FPropSplitInputBatch(self, theta, input_batch):
    """Splits the input batch on the input device."""
    cluster = self.cluster
    num_splits = cluster.num_splits_per_client

    if not isinstance(input_batch, list):
      input_batch = [input_batch]

    assert len(input_batch) == num_splits, (len(input_batch), num_splits)

    # dev_list_per_replica[i][j] is the i-th worker's j-th device.
    dev_list_per_replica = cluster.available_devices.tolist()

    # Asserts invariant of the total number of splits w.r.t.,
    # splits per worker.
    splits_per_replica = cluster.num_splits_per_replica
    assert num_splits == splits_per_replica * len(dev_list_per_replica), (
        num_splits, splits_per_replica, len(dev_list_per_replica))

    all_metrics = []
    all_per_example_tensors = []
    for w_id, w_devs in enumerate(dev_list_per_replica):
      # Make local copy of the vars, shard on devices for this worker.
      theta_local = py_utils.CreateLocalTheta(
          theta, w_devs, label='worker %d' % w_id)

      for s_id in range(splits_per_replica):
        # s_id-th split for the w_id-th worker.
        split_id = splits_per_replica * w_id + s_id
        with py_utils.ModelSplit(split_id):
          with tf.device(cluster.WorkerDeviceInModelSplit(0)):
            with tf.name_scope('tower_%d_%d' % (w_id, s_id)):
              batch = self.input_generator.PreprocessInputBatch(
                  input_batch[split_id])
              metrics, per_example = self.FPropTower(theta_local, batch)
        all_metrics.append(metrics)
        all_per_example_tensors.append(per_example)

    return py_utils.WeightedAvgOfMetrics(
        all_metrics), py_utils.ConcatPerExampleTensors(all_per_example_tensors)
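The split enumeration above maps every (worker, per-replica split) pair to a unique split_id in [0, num_splits), which is what the asserted invariant num_splits == splits_per_replica * num_replicas guarantees. A small standalone sketch of that mapping follows; the device strings are hypothetical placeholders.

# Illustrative sketch of the split enumeration used above: with 2 workers and
# 2 splits per replica, split ids 0..3 cover each (worker, split) pair once.
dev_list_per_replica = [['/gpu:0', '/gpu:1'], ['/gpu:2', '/gpu:3']]  # hypothetical
splits_per_replica = 2
for w_id, w_devs in enumerate(dev_list_per_replica):
  for s_id in range(splits_per_replica):
    split_id = splits_per_replica * w_id + s_id
    print(w_id, s_id, split_id)  # -> (0, 0, 0), (0, 1, 1), (1, 0, 2), (1, 1, 3)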
Example #3
  def testModelSplit(self):
    with py_utils.ModelSplit(2):
      assert py_utils.GetModelSplit() == 2
      with py_utils.ModelSplit(3):
        assert py_utils.GetModelSplit() == 3
    assert py_utils.GetModelSplit() == 0
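This test pins down the contract of py_utils.ModelSplit: it is a nestable context manager, GetModelSplit() returns the innermost active split id, and the default outside any context is 0. A minimal re-creation of that behavior is sketched below, assuming a simple module-level stack; it is illustrative only and not lingvo's actual implementation.

# Stack-backed sketch of the ModelSplit/GetModelSplit contract shown above
# (not lingvo's implementation).
import contextlib

_model_split_stack = [0]  # default split id is 0


@contextlib.contextmanager
def ModelSplit(split_id):
  _model_split_stack.append(split_id)
  try:
    yield
  finally:
    _model_split_stack.pop()


def GetModelSplit():
  return _model_split_stack[-1]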
Example #4
  def FProp(self, theta):
    """Forward propagation.

    This default `FProp` implementation here supports batch splitting in
    synchronous and asynchronous training when sub-classes implement
    `FPropTower`.

    Args:
      theta: A `.NestedMap` object containing weights' values of this
        layer and its children layers.

    Returns:
      A dict containing metrics pairs. One of the keys should be 'loss' and its
      value should be a (loss, num_predictions) pair.
    """
    p = self.params
    cluster = cluster_factory.Current()

    with tf.name_scope('fprop'), tf.name_scope(p.name):
      all_fprop_metrics = []

      if py_utils.use_tpu():
        batch = self.input_generator.CreateTpuFeeds()
        with tf.name_scope('tower_0_0'):
          dec_metrics = self.FPropTower(theta, batch)
        all_fprop_metrics.append(dec_metrics)
      else:
        # Splits the input batch on the input device.
        num_splits = cluster.num_splits_per_client
        with tf.device(cluster.input_device):
          batches = self.input_generator.SplitInputBatch(num_splits)
          assert num_splits == len(batches)

        # dev_list_per_replica[i][j] is the i-th worker's j-th device.
        dev_list_per_replica = cluster.available_devices.tolist()

        # Asserts invariant of the total number of splits w.r.t.,
        # splits per worker.
        splits_per_replica = cluster.num_splits_per_replica
        assert num_splits == splits_per_replica * len(dev_list_per_replica)

        for w_id, w_devs in enumerate(dev_list_per_replica):
          # Make local copy of the vars, shard on devices for this worker.
          theta_local = py_utils.CreateLocalTheta(
              theta, w_devs, label='worker %d' % w_id)

          for s_id in range(splits_per_replica):
            # s_id-th split for the w_id-th worker.
            split_id = splits_per_replica * w_id + s_id
            with py_utils.ModelSplit(split_id):
              with tf.device(cluster.WorkerDeviceInModelSplit(0)):
                with tf.name_scope('tower_%d_%d' % (w_id, s_id)):
                  batch = self.input_generator.PreprocessInputBatch(
                      batches[split_id])
                  dec_metrics = self.FPropTower(theta_local, batch)
            all_fprop_metrics.append(dec_metrics)

      metrics = py_utils.WeightedAvgOfMetrics(all_fprop_metrics)

    # Adds stats about the input batch.
    metrics['num_samples_in_batch'] = (tf.convert_to_tensor(
        self.input_generator.InputBatchSize()), tf.constant(1.0))
    # Generates summaries.
    for name, (value, weight) in six.iteritems(metrics):
      self.AddEvalMetric(name, value, weight)

    # Loss.
    self._loss, self._num_predicts = metrics['loss']
    self._loss = py_utils.CheckNumerics(self._loss)

    return metrics
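Each entry in the returned metrics dict is a (value, weight) pair, e.g. 'loss' maps to (loss, num_predictions), and py_utils.WeightedAvgOfMetrics reduces the per-tower dicts by weighted average. The plain-Python sketch below illustrates that reduction; it is an assumption for clarity, not lingvo's implementation, which operates on tensors.

# Illustrative sketch of a weighted average over per-tower metric dicts of
# (value, weight) pairs; not lingvo's implementation.
def weighted_avg_of_metrics(metric_dicts):
  out = {}
  for name in metric_dicts[0]:
    values, weights = zip(*(m[name] for m in metric_dicts))
    total_w = sum(weights)
    out[name] = (sum(v * w for v, w in zip(values, weights)) / total_w, total_w)
  return out

# Two towers reporting 'loss' as (loss, num_predictions):
towers = [{'loss': (2.0, 10.0)}, {'loss': (4.0, 30.0)}]
print(weighted_avg_of_metrics(towers))  # {'loss': (3.5, 40.0)}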