def _benchmark_eager_apply(self, label, device_and_format, defun=False,
                           execution_mode=None):
    """Benchmark ResNet50 inference (training=False) under eager execution.

    Runs a few untimed warm-up forward passes, collects garbage, then times a
    fixed number of forward passes and hands the start time to `self._report`.

    Args:
      label: name under which the measurement is reported.
      device_and_format: `(device, data_format)` pair; the model is built with
        `data_format` and run under `tf.device(device)`.
      defun: if True, the model's `call` is wrapped with `tfe.function`.
      execution_mode: passed to `tfe.execution_mode`; when truthy,
        `tfe.async_wait()` is called after the warm-up and after the timed
        loop.
    """
    with tfe.execution_mode(execution_mode):
        device, data_format = device_and_format
        model = resnet50.ResNet50(data_format)
        if defun:
            model.call = tfe.function(model.call)
        batch_size = 64
        warmup_steps, timed_steps = 5, 30
        with tf.device(device):
            images, _ = random_batch(batch_size, data_format)

            def run_forward(steps):
                # `.cpu()` copies every output to host memory, so each pass
                # actually finishes before the next one is issued.
                for _ in xrange(steps):
                    model(images, training=False).cpu()
                if execution_mode:
                    # Drain ops that were dispatched asynchronously.
                    tfe.async_wait()

            run_forward(warmup_steps)
            gc.collect()
            start = time.time()
            run_forward(timed_steps)
            self._report(label, start, timed_steps, device, batch_size,
                         data_format)
def _apply(self, defun=False, execution_mode=None):
    """Run one ResNet50 forward pass on a random batch and check its shape.

    Args:
      defun: if True, the model's `call` is wrapped with `tfe.function`.
      execution_mode: passed to `tfe.execution_mode` for the forward pass.
    """
    device, data_format = device_and_data_format()
    model = resnet50.ResNet50(data_format)
    if defun:
        model.call = tfe.function(model.call)
    batch = 2
    with tf.device(device):
        with tfe.execution_mode(execution_mode):
            images, _ = random_batch(batch, data_format)
            result = model(images, training=False)
            # Block until any asynchronously dispatched ops have finished
            # before the result is inspected.
            tfe.async_wait()
    self.assertEqual((batch, 1000), result.shape)
def _benchmark_eager_train(self, label, make_iterator, device_and_format,
                           defun=False, execution_mode=None):
    """Benchmark ResNet50 training steps under eager execution.

    For each batch size yielded by `self._train_batch_sizes()`, a fresh model
    and `GradientDescentOptimizer(0.1)` are built. A few untimed burn-in steps
    run first, then a timed loop whose start time is passed to `self._report`.

    Args:
      label: name under which each measurement is reported.
      make_iterator: callable taking an `(images, labels)` tuple and returning
        an object whose `.next()` method yields `(images, labels)` pairs.
      device_and_format: `(device, data_format)` pair.
      defun: if True, both `model.call` and `apply_gradients` are wrapped with
        `tfe.function`.
      execution_mode: passed to `tfe.execution_mode`; when truthy,
        `tfe.async_wait()` is called after each loop.
    """
    with tfe.execution_mode(execution_mode):
        device, data_format = device_and_format
        for batch_size in self._train_batch_sizes():
            images, labels = random_batch(batch_size, data_format)
            model = resnet50.ResNet50(data_format)
            optimizer = tf.train.GradientDescentOptimizer(0.1)
            if defun:
                model.call = tfe.function(model.call)
                apply_grads = tfe.function(apply_gradients)
            else:
                apply_grads = apply_gradients

            burn_steps, timed_steps = 3, 10
            with tf.device(device):
                iterator = make_iterator((images, labels))

                def train(steps):
                    # `.next()` is called explicitly (not the `next()`
                    # builtin) so iterators exposing only that method work.
                    for _ in xrange(steps):
                        step_images, step_labels = iterator.next()
                        grads = compute_gradients(model, step_images,
                                                  step_labels)
                        apply_grads(model, optimizer, grads)
                    if execution_mode:
                        tfe.async_wait()
                    self._force_device_sync()

                train(burn_steps)
                gc.collect()
                start = time.time()
                train(timed_steps)
                self._report(label, start, timed_steps, device, batch_size,
                             data_format)