def _FPropTpu(self, theta, input_batch):
  """Runs the single TPU tower inside this task's fprop name scopes.

  Args:
    theta: Weights of this task and its children, forwarded unchanged to
      `FPropTower`.
    input_batch: The input batch, forwarded unchanged to `FPropTower`.

  Returns:
    The result of `py_utils.WeightedAvgOfMetrics` applied to a one-element
    list holding the metrics produced by `FPropTower`.
  """
  task_name = self.params.name
  with tf.name_scope('fprop'), tf.name_scope(task_name):
    with tf.name_scope('tower_0_0'):
      tower_metrics = self.FPropTower(theta, input_batch)
      averaged = py_utils.WeightedAvgOfMetrics([tower_metrics])
    return averaged
def testWeightedAvgOfMetrics(self):
  """Checks WeightedAvgOfMetrics against hand-computed weighted averages."""
  first = {'a': (2.0, 0.5), 'b': (5.0, 1.5)}
  second = {'a': (9.0, 3.0), 'b': (4.0, 0.5)}
  # 'a': (2.0 * 0.5 + 9.0 * 3.0) / (0.5 + 3.0) = 8.0, total weight 3.5.
  # 'b': (5.0 * 1.5 + 4.0 * 0.5) / (1.5 + 0.5) = 4.75, total weight 2.0.
  want = {'a': (8.0, 3.5), 'b': (4.75, 2.0)}
  with self.session(use_gpu=False) as sess:
    # The averaging ops must be built while the test session is active.
    averaged = py_utils.WeightedAvgOfMetrics([first, second])
    got = sess.run(averaged)
    self.assertDictEqual(got, want)
def _FPropSplitInputBatch(self, theta, input_batch):
  """Splits the input batch on the input device.

  On TPU this delegates entirely to `self._FPropTpu`. Otherwise each split
  of `input_batch` is run through `FPropTower` under its own model split,
  using a local copy of `theta` created once per worker.

  Args:
    theta: Weights of this task and its children.
    input_batch: A single batch or a list of batches. After wrapping a
      non-list in a list, its length must equal
      `self.cluster.num_splits_per_client`.

  Returns:
    A `(metrics, per_example)` pair: all towers' metrics combined by
    `py_utils.WeightedAvgOfMetrics`, and all towers' per-example tensors
    combined by `py_utils.ConcatPerExampleTensors`.

  Raises:
    AssertionError: if the number of splits does not match the cluster
      configuration.
  """
  if py_utils.use_tpu():
    return self._FPropTpu(theta, input_batch)

  cluster = self.cluster
  num_splits = cluster.num_splits_per_client
  splits = input_batch if isinstance(input_batch, list) else [input_batch]
  assert len(splits) == num_splits, (len(splits), num_splits)

  # devices_per_worker[i][j] is the i-th worker's j-th device.
  devices_per_worker = cluster.available_devices.tolist()
  num_workers = len(devices_per_worker)

  # Every worker must host exactly `splits_per_replica` splits.
  splits_per_replica = cluster.num_splits_per_replica
  assert num_splits == splits_per_replica * num_workers, (
      num_splits, splits_per_replica, num_workers)

  tower_metrics = []
  tower_per_example = []
  with cluster:
    for worker_id, worker_devices in enumerate(devices_per_worker):
      # Local copy of the variables, sharded over this worker's devices.
      local_theta = py_utils.CreateLocalTheta(
          theta, worker_devices, label='worker %d' % worker_id)
      for split_in_worker in range(splits_per_replica):
        # Global index of this worker's `split_in_worker`-th split.
        split_id = splits_per_replica * worker_id + split_in_worker
        with cluster_factory.SetModelSplit(split_id) as split_cluster:
          with tf.device(split_cluster.WorkerDeviceInModelSplit(0)):
            with tf.name_scope('tower_%d_%d' % (worker_id, split_in_worker)):
              split_batch = splits[split_id]
              metrics, per_example = self.FPropTower(local_theta,
                                                     split_batch)
        tower_metrics.append(metrics)
        tower_per_example.append(per_example)

  averaged_metrics = py_utils.WeightedAvgOfMetrics(tower_metrics)
  concatenated = py_utils.ConcatPerExampleTensors(tower_per_example)
  return averaged_metrics, concatenated
def _FPropSplitInputBatch(self, theta, input_batch):
  """Splits the input batch on the input device.

  Each split is passed through `self.input_generator.PreprocessInputBatch`
  and then `FPropTower` under its own model split. The towers use a local
  copy of `theta` created once per worker.

  Args:
    theta: Weights of this task and its children.
    input_batch: A single batch or a list of batches. After wrapping a
      non-list in a list, its length must equal
      `self.cluster.num_splits_per_client`.

  Returns:
    All towers' metrics combined by `py_utils.WeightedAvgOfMetrics`.

  Raises:
    AssertionError: if the number of splits does not match the cluster
      configuration.
  """
  cluster = self.cluster
  num_splits = cluster.num_splits_per_client
  splits = input_batch if isinstance(input_batch, list) else [input_batch]
  assert len(splits) == num_splits, (len(splits), num_splits)

  # devices_per_worker[i][j] is the i-th worker's j-th device.
  devices_per_worker = cluster.available_devices.tolist()
  num_workers = len(devices_per_worker)

  # Every worker must host exactly `splits_per_replica` splits.
  splits_per_replica = cluster.num_splits_per_replica
  assert num_splits == splits_per_replica * num_workers, (
      num_splits, splits_per_replica, num_workers)

  tower_metrics = []
  for worker_id, worker_devices in enumerate(devices_per_worker):
    # Local copy of the variables, sharded over this worker's devices.
    local_theta = py_utils.CreateLocalTheta(
        theta, worker_devices, label='worker %d' % worker_id)
    for split_in_worker in range(splits_per_replica):
      # Global index of this worker's `split_in_worker`-th split.
      split_id = splits_per_replica * worker_id + split_in_worker
      with py_utils.ModelSplit(split_id):
        # Queried inside ModelSplit so it sees the split just entered.
        with tf.device(cluster.WorkerDeviceInModelSplit(0)):
          with tf.name_scope('tower_%d_%d' % (worker_id, split_in_worker)):
            preprocessed = self.input_generator.PreprocessInputBatch(
                splits[split_id])
            tower_metrics.append(self.FPropTower(local_theta, preprocessed))
  return py_utils.WeightedAvgOfMetrics(tower_metrics)
def FProp(self, theta):
  """Forward propagation.

  This default `FProp` implementation here supports batch splitting in
  synchronous and asynchronous training when sub-classes implement
  `FPropTower`.

  On TPU a single tower is built from `input_generator.CreateTpuFeeds()`.
  Otherwise the input is split `cluster.num_splits_per_client` ways on the
  input device, and each split is preprocessed and run through `FPropTower`
  under its own model split. The per-tower metrics are then combined with
  `py_utils.WeightedAvgOfMetrics`.

  Side effects: registers every metric via `self.AddEvalMetric` and sets
  `self._loss` / `self._num_predicts` from `metrics['loss']`.

  Args:
    theta: A `.NestedMap` object containing weights' values of this layer and
      its children layers.

  Returns:
    A dict containing metrics pairs. One of the keys should be 'loss' and its
    value should be a (loss, num_predictions) pair. An extra
    'num_samples_in_batch' entry (weight 1.0) is added here.

  Raises:
    AssertionError: if the split counts disagree with the cluster
      configuration (non-TPU path only).
  """
  p = self.params
  cluster = cluster_factory.Current()
  with tf.name_scope('fprop'), tf.name_scope(p.name):
    # Collects one metrics result per tower.
    all_fprop_metrics = []
    if py_utils.use_tpu():
      batch = self.input_generator.CreateTpuFeeds()
      with tf.name_scope('tower_0_0'):
        dec_metrics = self.FPropTower(theta, batch)
      all_fprop_metrics.append(dec_metrics)
    else:
      # Splits the input batch on the input device.
      num_splits = cluster.num_splits_per_client
      with tf.device(cluster.input_device):
        batches = self.input_generator.SplitInputBatch(num_splits)
        assert num_splits == len(batches)
      # dev_list_per_replica[i][j] is the i-th worker's j-th device.
      dev_list_per_replica = cluster.available_devices.tolist()
      # Asserts invariant of the total number of splits w.r.t.,
      # splits per worker.
      splits_per_replica = cluster.num_splits_per_replica
      assert num_splits == splits_per_replica * len(
          dev_list_per_replica)
      for w_id, w_devs in enumerate(dev_list_per_replica):
        # Make local copy of the vars, shard on devices for this worker.
        theta_local = py_utils.CreateLocalTheta(theta, w_devs,
                                                label='worker %d' % w_id)
        for s_id in range(splits_per_replica):
          # s_id-th split for the w_id-th worker.
          split_id = splits_per_replica * w_id + s_id
          with py_utils.ModelSplit(split_id):
            # Device is queried inside ModelSplit so it reflects this split.
            with tf.device(
                cluster.WorkerDeviceInModelSplit(0)):
              with tf.name_scope('tower_%d_%d' % (w_id, s_id)):
                batch = self.input_generator.PreprocessInputBatch(
                    batches[split_id])
                dec_metrics = self.FPropTower(
                    theta_local, batch)
          all_fprop_metrics.append(dec_metrics)
    # Combines the per-tower metrics into a single metrics dict.
    metrics = py_utils.WeightedAvgOfMetrics(all_fprop_metrics)
  # Adds stats about the input batch.
  metrics['num_samples_in_batch'] = (tf.convert_to_tensor(
      self.input_generator.InputBatchSize()), tf.constant(1.0))
  # Generates summaries.
  for name, (value, weight) in six.iteritems(metrics):
    self.AddEvalMetric(name, value, weight)
  # Loss.
  self._loss, self._num_predicts = metrics['loss']
  # NOTE(review): presumably guards against NaN/Inf in the loss; see
  # py_utils.CheckNumerics to confirm.
  self._loss = py_utils.CheckNumerics(self._loss)
  return metrics
def _FPropTpu(self, theta, input_batch):
  """Runs the single TPU tower and averages its metrics.

  Args:
    theta: Weights of this task and its children.
    input_batch: Input batch forwarded unchanged to `FPropTower`.

  Returns:
    A `(metrics, per_example)` pair. The first element is the tower's
    metrics passed through `py_utils.WeightedAvgOfMetrics`. The second is
    the per-example output of `FPropTower`, returned unchanged.
  """
  with tf.name_scope('tower_0_0'):
    tower_metrics, per_example_tensors = self.FPropTower(theta, input_batch)
    averaged = py_utils.WeightedAvgOfMetrics([tower_metrics])
  return averaged, per_example_tensors