def testWorkerDeviceInModelSplitSync(self):
  """Split 1 of a sync 4-GPU worker maps to GPU 3 on task 0."""
  params = cluster_factory.Cluster.Params()
  params.mode = 'sync'
  params.job = 'trainer_client'
  params.worker.name = '/job:trainer'
  params.worker.replicas = 4
  params.worker.gpus_per_replica = 4
  params.worker.devices_per_split = 2
  cluster = cluster_factory.Cluster(params)
  # Query the second device of model split 1.
  with py_utils.ModelSplit(1):
    actual_device = cluster.WorkerDeviceInModelSplit(1)
  # Split 1 owns devices 2 and 3, so index 1 within the split is GPU 3.
  expected_device = cluster._MakeDeviceString(
      job_name='/job:trainer', task_id=0, device_name='GPU', device_id=3)
  self.assertEqual(expected_device, actual_device)
def _FPropSplitInputBatch(self, theta, input_batch):
  """Splits the input batch on the input device."""
  cluster = self.cluster
  num_splits = cluster.num_splits_per_client

  # Normalize to a list: one entry per split.
  batches = input_batch if isinstance(input_batch, list) else [input_batch]
  assert len(batches) == num_splits, (len(batches), num_splits)

  # devices_by_replica[i][j] is the i-th worker's j-th device.
  devices_by_replica = cluster.available_devices.tolist()

  # Invariant: total splits == splits-per-worker * number of workers.
  splits_per_replica = cluster.num_splits_per_replica
  assert num_splits == splits_per_replica * len(devices_by_replica), (
      num_splits, splits_per_replica, len(devices_by_replica))

  metrics_list = []
  per_example_list = []
  for replica_id, replica_devices in enumerate(devices_by_replica):
    # Shard a local copy of the variables across this worker's devices.
    local_theta = py_utils.CreateLocalTheta(
        theta, replica_devices, label='worker %d' % replica_id)
    for split_offset in range(splits_per_replica):
      # The split_offset-th split handled by the replica_id-th worker.
      split_id = splits_per_replica * replica_id + split_offset
      with py_utils.ModelSplit(split_id):
        with tf.device(cluster.WorkerDeviceInModelSplit(0)):
          with tf.name_scope('tower_%d_%d' % (replica_id, split_offset)):
            batch = self.input_generator.PreprocessInputBatch(
                batches[split_id])
            metrics, per_example = self.FPropTower(local_theta, batch)
      metrics_list.append(metrics)
      per_example_list.append(per_example)

  return py_utils.WeightedAvgOfMetrics(
      metrics_list), py_utils.ConcatPerExampleTensors(per_example_list)
def testModelSplit(self):
  """ModelSplit contexts nest, and the split id resets on exit."""
  with py_utils.ModelSplit(2):
    assert 2 == py_utils.GetModelSplit()
    with py_utils.ModelSplit(3):
      # Inner context shadows the outer split id.
      assert 3 == py_utils.GetModelSplit()
  # Outside every context the split id is back to the default 0.
  assert 0 == py_utils.GetModelSplit()
def FProp(self, theta):
  """Forward propagation.

  This default `FProp` implementation here supports batch splitting in
  synchronous and asynchronous training when sub-classes implement
  `FPropTower`.

  Args:
    theta: A `.NestedMap` object containing weights' values of this layer
      and its children layers.

  Returns:
    A dict containing metrics pairs. One of the keys should be 'loss' and
    its value should be a (loss, num_predictions) pair.
  """
  p = self.params
  cluster = cluster_factory.Current()

  with tf.name_scope('fprop'), tf.name_scope(p.name):
    all_fprop_metrics = []

    if py_utils.use_tpu():
      # On TPU the infeed supplies a single batch; no client-side split.
      batch = self.input_generator.CreateTpuFeeds()
      with tf.name_scope('tower_0_0'):
        dec_metrics = self.FPropTower(theta, batch)
      all_fprop_metrics.append(dec_metrics)
    else:
      # Splits the input batch on the input device.
      num_splits = cluster.num_splits_per_client
      with tf.device(cluster.input_device):
        batches = self.input_generator.SplitInputBatch(num_splits)
        # Message tuple added for parity with _FPropSplitInputBatch so a
        # failure reports the mismatched counts.
        assert num_splits == len(batches), (num_splits, len(batches))

      # dev_list_per_replica[i][j] is the i-th worker's j-th device.
      dev_list_per_replica = cluster.available_devices.tolist()

      # Asserts invariant of the total number of splits w.r.t.,
      # splits per worker.
      splits_per_replica = cluster.num_splits_per_replica
      assert num_splits == splits_per_replica * len(dev_list_per_replica), (
          num_splits, splits_per_replica, len(dev_list_per_replica))

      for w_id, w_devs in enumerate(dev_list_per_replica):
        # Make local copy of the vars, shard on devices for this worker.
        theta_local = py_utils.CreateLocalTheta(
            theta, w_devs, label='worker %d' % w_id)
        for s_id in range(splits_per_replica):
          # s_id-th split for the w_id-th worker.
          split_id = splits_per_replica * w_id + s_id
          with py_utils.ModelSplit(split_id):
            with tf.device(cluster.WorkerDeviceInModelSplit(0)):
              with tf.name_scope('tower_%d_%d' % (w_id, s_id)):
                batch = self.input_generator.PreprocessInputBatch(
                    batches[split_id])
                dec_metrics = self.FPropTower(theta_local, batch)
          all_fprop_metrics.append(dec_metrics)

    # Weighted-average the per-tower metrics into a single metrics dict.
    metrics = py_utils.WeightedAvgOfMetrics(all_fprop_metrics)

    # Adds stats about the input batch.
    metrics['num_samples_in_batch'] = (tf.convert_to_tensor(
        self.input_generator.InputBatchSize()), tf.constant(1.0))
    # Generates summaries.
    for name, (value, weight) in six.iteritems(metrics):
      self.AddEvalMetric(name, value, weight)
    # Loss.
    self._loss, self._num_predicts = metrics['loss']
    self._loss = py_utils.CheckNumerics(self._loss)
    return metrics