def testModelSplit(self):
    """Checks that SetModelSplit scopes nest and restore split_id correctly."""
    cluster_params = cluster_factory.Cluster.Params()
    with cluster_params.Instantiate() as outer_cluster:
        # Entering a split scope should expose the requested split_id.
        with cluster_factory.SetModelSplit(2) as split_two:
            self.assertEqual(split_two.params.split_id, 2)
            # Nested scopes override the outer split_id.
            with cluster_factory.SetModelSplit(3) as split_three:
                self.assertEqual(split_three.params.split_id, 3)
        # Back at the top level, the default split_id (0) is in effect.
        self.assertEqual(outer_cluster.params.split_id, 0)
def testWorkerDeviceInModelSplitSync(self):
    """Verifies device placement for a model split in sync training mode."""
    cluster_params = cluster_factory.Cluster.Params()
    cluster_params.mode = 'sync'
    cluster_params.job = 'trainer_client'
    worker = cluster_params.worker
    worker.name = '/job:trainer'
    worker.replicas = 4
    worker.gpus_per_replica = 4
    worker.devices_per_split = 2
    with cluster_factory.Cluster(cluster_params):
        with cluster_factory.SetModelSplit(1) as split_cluster:
            actual_device = split_cluster.WorkerDeviceInModelSplit(1)
            # Split 1 with 2 devices per split maps its second device to
            # GPU 3 on the first trainer replica.
            expected_device = split_cluster._MakeDeviceString(
                job_name='/job:trainer',
                task_id=0,
                device_name='GPU',
                device_id=3)
            self.assertEqual(expected_device, actual_device)
def _FPropSplitInputBatch(self, theta, input_batch):
    """Splits the input batch on the input device.

    Runs forward propagation once per model split, placing each split's
    tower on its assigned worker device, then aggregates the per-tower
    results.

    Args:
      theta: Model weights passed through to FPropTower (structure not
        visible here; presumably a nested map of variables — confirm
        against py_utils.CreateLocalTheta).
      input_batch: A single batch or a list of batches. When a list, it
        must have exactly one entry per split
        (cluster.num_splits_per_client).

    Returns:
      A tuple of (weighted average of the per-tower metrics, concatenation
      of the per-tower per-example tensors).
    """
    # TPU has its own dedicated forward-prop path; no manual split here.
    if py_utils.use_tpu():
        return self._FPropTpu(theta, input_batch)
    cluster = self.cluster
    num_splits = cluster.num_splits_per_client
    # Normalize a single batch into a one-element list so the indexing
    # below is uniform.
    if not isinstance(input_batch, list):
        input_batch = [input_batch]
    assert len(input_batch) == num_splits, (len(input_batch), num_splits)
    # dev_list_per_replica[i][j] is the i-th worker's j-th device.
    dev_list_per_replica = cluster.available_devices.tolist()
    # Asserts invariant of the total number of splits w.r.t.,
    # splits per worker.
    splits_per_replica = cluster.num_splits_per_replica
    assert num_splits == splits_per_replica * len(dev_list_per_replica), (
        num_splits, splits_per_replica, len(dev_list_per_replica))
    all_metrics = []
    all_per_example_tensors = []
    with cluster:
        for w_id, w_devs in enumerate(dev_list_per_replica):
            # Make local copy of the vars, shard on devices for this worker.
            theta_local = py_utils.CreateLocalTheta(
                theta, w_devs, label='worker %d' % w_id)
            for s_id in range(splits_per_replica):
                # s_id-th split for the w_id-th worker.
                split_id = splits_per_replica * w_id + s_id
                # Each split's tower is pinned to the first device of its
                # model split and gets its own name scope for the graph.
                with cluster_factory.SetModelSplit(split_id) as c:
                    with tf.device(c.WorkerDeviceInModelSplit(0)):
                        with tf.name_scope('tower_%d_%d' % (w_id, s_id)):
                            batch = input_batch[split_id]
                            metrics, per_example = self.FPropTower(
                                theta_local, batch)
                all_metrics.append(metrics)
                all_per_example_tensors.append(per_example)
    # Aggregate across towers: metrics are combined by weighted average,
    # per-example tensors by concatenation.
    return py_utils.WeightedAvgOfMetrics(
        all_metrics), py_utils.ConcatPerExampleTensors(all_per_example_tensors)