def testNoNestedPmap(self):
  devices = self._get_two_devices(require_same_type=True)

  def f(x):
    return x + 1.0

  data = tf_np.asarray(tf.convert_to_tensor(value=[1, 3]))
  with self.assertRaisesWithPredicateMatch(
      ValueError, r"Nested pmap is not supported"):
    f = extensions.pmap(f, devices=devices)
    f = extensions.pmap(f, devices=devices)
    f(data)
def testAxisName(self):
  devices = self._get_two_devices(require_same_type=True)

  def reduce_sum(f):
    return extensions.psum(f, axis_name="foo")

  data = tf_np.asarray(tf.convert_to_tensor(value=[1, 3]))
  pmapped = extensions.pmap(reduce_sum, axis_name="foo", devices=devices)
  pmapped(data)
def testWrongAxisName(self):
  devices = self._get_two_devices(require_same_type=True)

  def reduce_sum(f):
    return extensions.psum(f, axis_name="bar")

  data = tf_np.asarray(tf.convert_to_tensor(value=[1, 3]))
  with self.assertRaisesWithPredicateMatch(
      ValueError, r"axis_name (.*) is not equal to that of the surrounding"):
    pmapped = extensions.pmap(reduce_sum, axis_name="foo", devices=devices)
    pmapped(data)
def testPsum(self):
  devices = self._get_two_devices(require_same_type=True)

  def reduce_sum(f):
    return extensions.psum(f)

  data = tf_np.asarray(tf.convert_to_tensor(value=[1, 3]))
  pmapped = extensions.pmap(reduce_sum, devices=devices)
  result = pmapped(data)
  self.assertAllClose(result[0], 4)
  self.assertAllClose(result[1], 4)
def testPmean(self):
  if extensions.tpu_devices():
    self.skipTest("pmean for TPU is not supported yet")
  devices = self._get_two_devices(require_same_type=True)

  def reduce_mean(f):
    return extensions.pmean(f)

  data = tf_np.asarray(tf.convert_to_tensor(value=[1, 3]))
  pmapped = extensions.pmap(reduce_mean, devices=devices)
  result = pmapped(data)
  self.assertAllClose(result[0], 2)
  self.assertAllClose(result[1], 2)
def testPmapSimpleModel(self):
  devices = self._get_two_devices(require_same_type=True)
  n_devices = len(devices)

  params, params_true, inputs, targets = generate_params_inputs_targets()

  def _train_and_reduce(params, inputs, targets, learning_rate=0.1):
    new_w, new_b = train_step(params, inputs, targets, learning_rate)
    return (extensions.psum(new_w) / n_devices,
            extensions.psum(new_b) / n_devices)

  train_step_pmapped = extensions.pmap(_train_and_reduce, devices=devices)

  def replicate(x, num_devices=2):
    return array_manipulation.broadcast_to(x, (num_devices,) + x.shape)

  params = tf.nest.map_structure(replicate, params)

  def reshape(x, num_devices=2):
    x_shape = list(x.shape)
    batch_size = x_shape[0]
    batch_size_per_device = batch_size // num_devices
    # New shape.
    new_shape_prefix = [num_devices, batch_size_per_device]
    return array_methods.reshape(x, new_shape_prefix + x_shape[1:])

  inputs = tf.nest.map_structure(reshape, inputs)
  targets = tf.nest.map_structure(reshape, targets)

  for _ in range(50):
    params = train_step_pmapped(params, inputs, targets)

  # PMAP returns sharded tensors.
  # Since the inputs are identical, the returned tensors should be identical.
  self.assertAllClose(params[0][0], params[0][1])
  self.assertAllClose(params[1][0], params[1][1])

  # This is not trained super well, but it usually gets "close".
  self.assertAllClose(params[0][0], params_true[0], atol=1e-1)
  self.assertAllClose(params[1][0], params_true[1], atol=1e-1)
def testPsumStruct(self):
  devices = self._get_two_devices(require_same_type=True)

  def reduce_sum(a):
    a = extensions.psum(a)
    tf.nest.map_structure(
        lambda x: self.assertIsInstance(x, tf_np.ndarray), a)
    return a

  data = [tf_np.asarray([1, 3]), tf_np.asarray([2, 4], np.int64)]
  pmapped = extensions.pmap(reduce_sum, devices=devices)
  result = pmapped(data)
  self.assertIsInstance(result[0][0], tf_np.ndarray)
  self.assertIsInstance(result[0][1], tf_np.ndarray)
  self.assertIsInstance(result[1][0], tf_np.ndarray)
  self.assertIsInstance(result[1][1], tf_np.ndarray)
  self.assertAllClose(result[0][0], 4)
  self.assertAllClose(result[0][1], 4)
  self.assertAllClose(result[1][0], 6)
  self.assertAllClose(result[1][1], 6)
def _tf_pmap(*args, **kwargs):
  kwargs.pop('donate_argnums', None)  # donate_argnums not used in TF
  return tf_np_extensions.pmap(*args, **kwargs)