Example #1
  def _FPropTpu(self, theta, input_batch):
    # On TPU there is a single tower, so wrap the tower metrics in a
    # one-element weighted average to match the multi-split code path.
    p = self.params
    with tf.name_scope('fprop'), tf.name_scope(p.name):
      with tf.name_scope('tower_0_0'):
        metrics = self.FPropTower(theta, input_batch)
        metrics = py_utils.WeightedAvgOfMetrics([metrics])
    return metrics
Example #2
 def testWeightedAvgOfMetrics(self):
   with self.session(use_gpu=False) as sess:
     metrics = [{
         'a': (2.0, 0.5),
         'b': (5.0, 1.5)
     }, {
         'a': (9.0, 3.0),
         'b': (4.0, 0.5)
     }]
     expected = {'a': (8.0, 3.5), 'b': (4.75, 2.0)}
     weighted_avg = py_utils.WeightedAvgOfMetrics(metrics)
     actual = sess.run(weighted_avg)
     self.assertDictEqual(actual, expected)
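
The expected values follow the per-key weighted-average rule: each metric is a
(value, weight) pair, and for 'a' the combined value is
(2.0 * 0.5 + 9.0 * 3.0) / (0.5 + 3.0) = 8.0 with total weight 3.5, while for
'b' it is (5.0 * 1.5 + 4.0 * 0.5) / (1.5 + 0.5) = 4.75 with total weight 2.0.
Below is a minimal pure-Python sketch of that combination rule; the real
py_utils.WeightedAvgOfMetrics builds the equivalent TensorFlow ops rather
than plain floats.

 def weighted_avg_of_metrics(metrics):
   # metrics: list of dicts mapping name -> (value, weight).
   # Returns a dict with the weight-averaged value and summed weight per name.
   result = {}
   for name in metrics[0]:
     total_weight = sum(w for _, w in (m[name] for m in metrics))
     avg = sum(v * w for v, w in (m[name] for m in metrics)) / total_weight
     result[name] = (avg, total_weight)
   return result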
Example #3
  def _FPropSplitInputBatch(self, theta, input_batch):
    """Splits the input batch on the input device."""
    if py_utils.use_tpu():
      return self._FPropTpu(theta, input_batch)

    cluster = self.cluster
    num_splits = cluster.num_splits_per_client

    if not isinstance(input_batch, list):
      input_batch = [input_batch]

    assert len(input_batch) == num_splits, (len(input_batch), num_splits)

    # dev_list_per_replica[i][j] is the i-th worker's j-th device.
    dev_list_per_replica = cluster.available_devices.tolist()

    # Asserts the invariant that the total number of splits equals
    # splits per worker times the number of workers.
    splits_per_replica = cluster.num_splits_per_replica
    assert num_splits == splits_per_replica * len(dev_list_per_replica), (
        num_splits, splits_per_replica, len(dev_list_per_replica))

    all_metrics = []
    all_per_example_tensors = []
    with cluster:
      for w_id, w_devs in enumerate(dev_list_per_replica):
        # Make a local copy of the variables, sharded across this worker's
        # devices.
        theta_local = py_utils.CreateLocalTheta(
            theta, w_devs, label='worker %d' % w_id)

        for s_id in range(splits_per_replica):
          # s_id-th split for the w_id-th worker.
          split_id = splits_per_replica * w_id + s_id
          with cluster_factory.SetModelSplit(split_id) as c:
            with tf.device(c.WorkerDeviceInModelSplit(0)):
              with tf.name_scope('tower_%d_%d' % (w_id, s_id)):
                batch = input_batch[split_id]
                metrics, per_example = self.FPropTower(theta_local, batch)
          all_metrics.append(metrics)
          all_per_example_tensors.append(per_example)
    return py_utils.WeightedAvgOfMetrics(
        all_metrics), py_utils.ConcatPerExampleTensors(all_per_example_tensors)
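
The split indexing above enumerates splits in worker-major order:
split_id = splits_per_replica * w_id + s_id, so worker 0 owns splits 0
through splits_per_replica - 1, worker 1 the next block, and so on. A small
standalone illustration of that mapping (the counts here are made up for
the example):

 # Worker-major flattening of (worker, split) pairs, mirroring
 # split_id = splits_per_replica * w_id + s_id in the method above.
 splits_per_replica = 2
 num_workers = 3
 for w_id in range(num_workers):
   for s_id in range(splits_per_replica):
     split_id = splits_per_replica * w_id + s_id
     print('worker %d, split %d -> input_batch[%d]' % (w_id, s_id, split_id))
 # worker 0 consumes input_batch[0..1], worker 1 [2..3], worker 2 [4..5].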
Example #4
    def _FPropSplitInputBatch(self, theta, input_batch):
        """Splits the input batch on the input device."""
        cluster = self.cluster
        num_splits = cluster.num_splits_per_client

        if not isinstance(input_batch, list):
            input_batch = [input_batch]

        assert len(input_batch) == num_splits, (len(input_batch), num_splits)

        # dev_list_per_replica[i][j] is the i-th worker's j-th device.
        dev_list_per_replica = cluster.available_devices.tolist()

        # Asserts the invariant that the total number of splits equals
        # splits per worker times the number of workers.
        splits_per_replica = cluster.num_splits_per_replica
        assert num_splits == splits_per_replica * len(dev_list_per_replica), (
            num_splits, splits_per_replica, len(dev_list_per_replica))

        all_fprop_metrics = []
        for w_id, w_devs in enumerate(dev_list_per_replica):
            # Make a local copy of the variables, sharded across this
            # worker's devices.
            theta_local = py_utils.CreateLocalTheta(theta,
                                                    w_devs,
                                                    label='worker %d' % w_id)

            for s_id in range(splits_per_replica):
                # s_id-th split for the w_id-th worker.
                split_id = splits_per_replica * w_id + s_id
                with py_utils.ModelSplit(split_id):
                    with tf.device(cluster.WorkerDeviceInModelSplit(0)):
                        with tf.name_scope('tower_%d_%d' % (w_id, s_id)):
                            batch = self.input_generator.PreprocessInputBatch(
                                input_batch[split_id])
                            dec_metrics = self.FPropTower(theta_local, batch)
                all_fprop_metrics.append(dec_metrics)

        return py_utils.WeightedAvgOfMetrics(all_fprop_metrics)
Example #5
    def FProp(self, theta):
        """Forward propagation.

    This default `FProp` implementation here supports batch splitting in
    synchronous and asynchronous training when sub-classes implement
    `FPropTower`.

    Args:
      theta: A `.NestedMap` object containing weights' values of this
        layer and its children layers.

    Returns:
      A dict containing metrics pairs. One of the keys should be 'loss' and its
      value should be a (loss, num_predictions) pair.
    """
        p = self.params
        cluster = cluster_factory.Current()

        with tf.name_scope('fprop'), tf.name_scope(p.name):
            all_fprop_metrics = []

            if py_utils.use_tpu():
                batch = self.input_generator.CreateTpuFeeds()
                with tf.name_scope('tower_0_0'):
                    dec_metrics = self.FPropTower(theta, batch)
                all_fprop_metrics.append(dec_metrics)
            else:
                # Splits the input batch on the input device.
                num_splits = cluster.num_splits_per_client
                with tf.device(cluster.input_device):
                    batches = self.input_generator.SplitInputBatch(num_splits)
                    assert num_splits == len(batches)

                # dev_list_per_replica[i][j] is the i-th worker's j-th device.
                dev_list_per_replica = cluster.available_devices.tolist()

                # Asserts the invariant that the total number of splits
                # equals splits per worker times the number of workers.
                splits_per_replica = cluster.num_splits_per_replica
                assert num_splits == splits_per_replica * len(
                    dev_list_per_replica)

                for w_id, w_devs in enumerate(dev_list_per_replica):
                    # Make a local copy of the variables, sharded across
                    # this worker's devices.
                    theta_local = py_utils.CreateLocalTheta(theta,
                                                            w_devs,
                                                            label='worker %d' %
                                                            w_id)

                    for s_id in range(splits_per_replica):
                        # s_id-th split for the w_id-th worker.
                        split_id = splits_per_replica * w_id + s_id
                        with py_utils.ModelSplit(split_id):
                            with tf.device(
                                    cluster.WorkerDeviceInModelSplit(0)):
                                with tf.name_scope('tower_%d_%d' %
                                                   (w_id, s_id)):
                                    batch = self.input_generator.PreprocessInputBatch(
                                        batches[split_id])
                                    dec_metrics = self.FPropTower(
                                        theta_local, batch)
                        all_fprop_metrics.append(dec_metrics)

            metrics = py_utils.WeightedAvgOfMetrics(all_fprop_metrics)

        # Adds stats about the input batch.
        metrics['num_samples_in_batch'] = (tf.convert_to_tensor(
            self.input_generator.InputBatchSize()), tf.constant(1.0))
        # Generates summaries.
        for name, (value, weight) in six.iteritems(metrics):
            self.AddEvalMetric(name, value, weight)

        # Loss.
        self._loss, self._num_predicts = metrics['loss']
        self._loss = py_utils.CheckNumerics(self._loss)

        return metrics
Example #6
  def _FPropTpu(self, theta, input_batch):
    # Same single-tower TPU path as Example #1, but this variant also
    # returns the per-example tensors from FPropTower.
    with tf.name_scope('tower_0_0'):
      metrics, per_example = self.FPropTower(theta, input_batch)
      metrics = py_utils.WeightedAvgOfMetrics([metrics])
    return metrics, per_example