Example #1
  def testCreateLocalTheta(self):
    methods = [py_utils.WeightInit.Gaussian, py_utils.WeightInit.Uniform]
    dtypes = [tf.float32, tf.complex64]
    shapes = [[2, 4], [3]]

    test_vars = py_utils.NestedMap()
    for i, (m, dt, sp) in enumerate(itertools.product(methods, dtypes, shapes)):
      pc = py_utils.WeightParams(sp, m(), dt, 'col1')
      test_vars['var_%d' % i] = py_utils.CreateVariable('var_%d' % i, pc)[0]

    test_devices = [
        '/job:worker/replica:0/device:GPU:0',
        '/job:worker/replica:0/device:GPU:1'
    ]

    sharded_local_vars = py_utils.CreateLocalTheta(test_vars, test_devices)
    sharded_local_vars_list = sharded_local_vars.Flatten()

    # assert the name is now Identity*
    for v in sharded_local_vars_list:
      self.assertTrue('Identity' in v.name)

    # assert proper device placement
    for i, v in enumerate(sharded_local_vars_list):
      expected_device = test_devices[i % len(test_devices)]
      self.assertEqual(v.device, expected_device)
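
From the assertions in this test, `CreateLocalTheta` appears to walk the flattened `NestedMap` and place an `Identity` copy of each variable on the given devices in round-robin order. A minimal sketch of that behaviour, with a hypothetical helper name and no claim to match the library's actual implementation:

import tensorflow as tf

def create_local_theta_sketch(theta, devices):
  """Round-robin device copies of a NestedMap of variables (sketch only).

  Each flattened entry is wrapped in tf.identity under a tf.device scope,
  so the copies are named 'Identity*' and land on devices[i % len(devices)],
  matching what the test above asserts.
  """
  copies = []
  for i, v in enumerate(theta.Flatten()):
    with tf.device(devices[i % len(devices)]):
      copies.append(tf.identity(v))
  return theta.Pack(copies)
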
Example #2
  def _FPropSplitInputBatch(self, theta, input_batch):
    """Splits the input batch on the input device."""
    if py_utils.use_tpu():
      return self._FPropTpu(theta, input_batch)

    cluster = self.cluster
    num_splits = cluster.num_splits_per_client

    if not isinstance(input_batch, list):
      input_batch = [input_batch]

    assert len(input_batch) == num_splits, (len(input_batch), num_splits)

    # dev_list_per_replica[i][j] is the i-th worker's j-th device.
    dev_list_per_replica = cluster.available_devices.tolist()

    # Asserts the invariant: total number of splits ==
    # splits per worker * number of workers.
    splits_per_replica = cluster.num_splits_per_replica
    assert num_splits == splits_per_replica * len(dev_list_per_replica), (
        num_splits, splits_per_replica, len(dev_list_per_replica))

    all_metrics = []
    all_per_example_tensors = []
    with cluster:
      for w_id, w_devs in enumerate(dev_list_per_replica):
        # Make a local copy of the vars, sharded across this worker's devices.
        theta_local = py_utils.CreateLocalTheta(
            theta, w_devs, label='worker %d' % w_id)

        for s_id in range(splits_per_replica):
          # s_id-th split for the w_id-th worker.
          split_id = splits_per_replica * w_id + s_id
          with cluster_factory.SetModelSplit(split_id) as c:
            with tf.device(c.WorkerDeviceInModelSplit(0)):
              with tf.name_scope('tower_%d_%d' % (w_id, s_id)):
                batch = input_batch[split_id]
                metrics, per_example = self.FPropTower(theta_local, batch)
          all_metrics.append(metrics)
          all_per_example_tensors.append(per_example)
    return py_utils.WeightedAvgOfMetrics(
        all_metrics), py_utils.ConcatPerExampleTensors(all_per_example_tensors)
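
The index arithmetic `split_id = splits_per_replica * w_id + s_id` assigns splits to workers in worker-major order, and together with the assert above it guarantees that every split in `range(num_splits)` is consumed exactly once. A standalone check of that indexing, with example sizes assumed purely for illustration:

# Assumed sizes: 2 workers, 3 splits per worker.
num_workers, splits_per_replica = 2, 3
assignment = {}
for w_id in range(num_workers):
  for s_id in range(splits_per_replica):
    # Same worker-major indexing as in _FPropSplitInputBatch above.
    assignment[splits_per_replica * w_id + s_id] = (w_id, s_id)
# Every split id in [0, num_splits) is covered exactly once.
assert sorted(assignment) == list(range(num_workers * splits_per_replica))
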
Example #3
    def _FPropSplitInputBatch(self, theta, input_batch):
        """Splits the input batch on the input device."""
        cluster = self.cluster
        num_splits = cluster.num_splits_per_client

        if not isinstance(input_batch, list):
            input_batch = [input_batch]

        assert len(input_batch) == num_splits, (len(input_batch), num_splits)

        # dev_list_per_replica[i][j] is the i-th worker's j-th device.
        dev_list_per_replica = cluster.available_devices.tolist()

        # Asserts the invariant: total number of splits ==
        # splits per worker * number of workers.
        splits_per_replica = cluster.num_splits_per_replica
        assert num_splits == splits_per_replica * len(dev_list_per_replica), (
            num_splits, splits_per_replica, len(dev_list_per_replica))

        all_fprop_metrics = []
        for w_id, w_devs in enumerate(dev_list_per_replica):
            # Make a local copy of the vars, sharded across this worker's devices.
            theta_local = py_utils.CreateLocalTheta(theta,
                                                    w_devs,
                                                    label='worker %d' % w_id)

            for s_id in range(splits_per_replica):
                # s_id-th split for the w_id-th worker.
                split_id = splits_per_replica * w_id + s_id
                with py_utils.ModelSplit(split_id):
                    with tf.device(cluster.WorkerDeviceInModelSplit(0)):
                        with tf.name_scope('tower_%d_%d' % (w_id, s_id)):
                            batch = self.input_generator.PreprocessInputBatch(
                                input_batch[split_id])
                            dec_metrics = self.FPropTower(theta_local, batch)
                all_fprop_metrics.append(dec_metrics)

        return py_utils.WeightedAvgOfMetrics(all_fprop_metrics)
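
Both of the examples above combine the per-tower metrics with `py_utils.WeightedAvgOfMetrics`. Assuming each metrics dict maps a name to a `(value, weight)` pair, the combination is a weight-weighted average per name; a plain-Python sketch, illustrative only since the real helper operates on tensors:

def weighted_avg_of_metrics_sketch(metrics_list):
  """Combines dicts of (value, weight) pairs from several towers (sketch).

  The combined value is the weight-weighted average of the per-tower values,
  and the combined weight is the total weight.
  """
  combined = {}
  for name in metrics_list[0]:
    pairs = [m[name] for m in metrics_list]
    total_weight = sum(w for _, w in pairs)
    avg_value = sum(v * w for v, w in pairs) / max(total_weight, 1e-8)
    combined[name] = (avg_value, total_weight)
  return combined
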
Example #4
    def FProp(self, theta):
        """Forward propagation.

        This default `FProp` implementation supports batch splitting in
        synchronous and asynchronous training when sub-classes implement
        `FPropTower`.

        Args:
          theta: A `.NestedMap` object containing weights' values of this
            layer and its child layers.

        Returns:
          A dict containing metrics pairs. One of the keys should be 'loss' and
          its value should be a (loss, num_predictions) pair.
        """
        p = self.params
        cluster = cluster_factory.Current()

        with tf.name_scope('fprop'), tf.name_scope(p.name):
            all_fprop_metrics = []

            if py_utils.use_tpu():
                batch = self.input_generator.CreateTpuFeeds()
                with tf.name_scope('tower_0_0'):
                    dec_metrics = self.FPropTower(theta, batch)
                all_fprop_metrics.append(dec_metrics)
            else:
                # Splits the input batch on the input device.
                num_splits = cluster.num_splits_per_client
                with tf.device(cluster.input_device):
                    batches = self.input_generator.SplitInputBatch(num_splits)
                    assert num_splits == len(batches)

                # dev_list_per_replica[i][j] is the i-th worker's j-th device.
                dev_list_per_replica = cluster.available_devices.tolist()

                # Asserts the invariant: total number of splits ==
                # splits per worker * number of workers.
                splits_per_replica = cluster.num_splits_per_replica
                assert num_splits == splits_per_replica * len(
                    dev_list_per_replica)

                for w_id, w_devs in enumerate(dev_list_per_replica):
                    # Make a local copy of the vars, sharded across this worker's devices.
                    theta_local = py_utils.CreateLocalTheta(theta,
                                                            w_devs,
                                                            label='worker %d' %
                                                            w_id)

                    for s_id in range(splits_per_replica):
                        # s_id-th split for the w_id-th worker.
                        split_id = splits_per_replica * w_id + s_id
                        with py_utils.ModelSplit(split_id):
                            with tf.device(
                                    cluster.WorkerDeviceInModelSplit(0)):
                                with tf.name_scope('tower_%d_%d' %
                                                   (w_id, s_id)):
                                    batch = self.input_generator.PreprocessInputBatch(
                                        batches[split_id])
                                    dec_metrics = self.FPropTower(
                                        theta_local, batch)
                        all_fprop_metrics.append(dec_metrics)

            metrics = py_utils.WeightedAvgOfMetrics(all_fprop_metrics)

        # Adds stats about the input batch.
        metrics['num_samples_in_batch'] = (tf.convert_to_tensor(
            self.input_generator.InputBatchSize()), tf.constant(1.0))
        # Generates summaries.
        for name, (value, weight) in six.iteritems(metrics):
            self.AddEvalMetric(name, value, weight)

        # Loss.
        self._loss, self._num_predicts = metrics['loss']
        self._loss = py_utils.CheckNumerics(self._loss)

        return metrics
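
Per the docstring, the returned dict maps metric names to `(value, weight)` pairs, with 'loss' mapping to `(loss, num_predictions)`. A hypothetical caller (the `task` and `optimizer` objects are assumed here, not part of the example) would unpack it like this:

# Hypothetical usage of the metrics contract described in the docstring.
metrics = task.FProp(task.theta)
loss, num_predictions = metrics['loss']
train_op = optimizer.minimize(loss)
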