def run_steps_on_dataset(self, fn, iterator, iterations):
  """Enqueue input from `iterator` and run `fn` for `iterations` steps on TPU.

  Builds a host-side while loop that, per iteration, dequeues one batch from
  `iterator` for each core on this host and enqueues it to that core's
  infeed; and a TPU computation that dequeues from infeed and applies `fn`
  `iterations` times on each core.

  Args:
    fn: Step function; called with the dequeued, re-packed input structure.
    iterator: A `tf.data` iterator whose output shapes are fully defined.
    iterations: Number of step-function invocations per call.

  Returns:
    An op grouping the distributed TPU computation with the enqueue loop.

  Raises:
    ValueError: If any of the iterator's output shapes is not fully defined
      (TPU infeed requires static shapes).
  """
  # Enqueue ops
  shapes = nest.flatten(iterator.output_shapes)
  # Generator (not a list) inside any() short-circuits on the first
  # not-fully-defined shape.
  if any(not s.is_fully_defined() for s in shapes):
    raise ValueError(
        'TPU currently requires fully defined shapes. Either use '
        'set_shape() on the input tensors or use '
        'dataset.apply(map_and_batch(..., drop_remainder=True)).')
  types = nest.flatten(iterator.output_types)

  def enqueue_ops_fn():
    """Enqueue ops for one iteration."""
    control_deps = []
    sharded_inputs = []
    with ops.device(self._host):
      for _ in range(self._num_cores_per_host):
        # Use control dependencies to ensure a deterministic ordering.
        with ops.control_dependencies(control_deps):
          inputs = nest.flatten(iterator.get_next())
          control_deps.extend(inputs)
          sharded_inputs.append(inputs)

    enqueue_ops = []
    for core_id, shard_input in enumerate(sharded_inputs):
      enqueue_ops.append(
          tpu_ops.infeed_enqueue_tuple(
              inputs=shard_input, shapes=shapes, device_ordinal=core_id))
    return enqueue_ops

  def enqueue_ops_loop_body(i):
    # Chain each iteration's enqueue ops behind the counter update so the
    # while loop actually drives the enqueues.
    with ops.control_dependencies(enqueue_ops_fn()):
      return i + 1

  with ops.device(self._host):
    enqueue_ops = control_flow_ops.while_loop(
        lambda i: i < iterations,
        enqueue_ops_loop_body, [constant_op.constant(0)],
        parallel_iterations=1)

  # Dequeue ops
  def dequeue_fn():
    dequeued = tpu.infeed_dequeue_tuple(dtypes=types, shapes=shapes)
    # Re-pack the flat dequeued tensors into the iterator's structure.
    return nest.pack_sequence_as(iterator.output_shapes, dequeued)

  # Wrap `fn` for repeat.
  run_fn = lambda: fn(dequeue_fn())

  # Repeat
  def iterate_on_tpu():
    return tpu.repeat(iterations, run_fn, [])

  # Re-write and distribute computation.
  tpu_result = tpu.batch_parallel(
      iterate_on_tpu, [], num_shards=self._num_cores_per_host)

  return control_flow_ops.group(tpu_result, enqueue_ops)
def _call_for_each_tower(self, fn, *args, **kwargs):
  """Run `fn` on a two-core TPU, feeding it a split batch via infeed."""
  kwargs.pop('run_concurrently', None)
  # TODO(isaprykin): Give an API for many iterations per step.
  iterations = 1

  # TODO(isaprykin): Do not hard code shapes and input format :)
  # TODO(isaprykin): Detect the number of TPU cores automatically.
  def dequeueing_fn(*args, **kwargs):
    # Loop-carried values are unused; the input comes from infeed.
    del args, kwargs
    x, = tpu.infeed_dequeue_tuple(dtypes=[dtypes.float32], shapes=[[1, 1, 1]])
    return fn(x)

  iterator = args[0]

  def infeed_input(i):
    """Get input, split it and then enqueue."""
    halves = array_ops.split(iterator.get_next(), 2)

    enqueues = []
    for core in range(2):
      enqueues.append(
          tpu_ops.infeed_enqueue_tuple(
              inputs=[halves[core]], shapes=[[1, 1, 1]], device_ordinal=core))

    with ops.control_dependencies(enqueues):
      return i + 1

  with ops.device('/task:0/device:CPU:0'):
    enqueue_ops = control_flow_ops.while_loop(
        lambda i: i < iterations,
        infeed_input, [constant_op.constant(0)],
        parallel_iterations=1)

  def iterate_on_tpu():
    return tpu.repeat(iterations, dequeueing_fn, [])

  with one_device_strategy._OneDeviceTowerContext(self):  # pylint: disable=protected-access
    tpu_result = tpu.batch_parallel(iterate_on_tpu, [], num_shards=2)

  return control_flow_ops.group(tpu_result, enqueue_ops)
N = 4096 COUNT = 100 def flops(): x = tf.random_uniform([N, N]) y = tf.random_uniform([N, N]) def _matmul(x, y): return tf.tensordot(x, y, axes=[[1], [0]]), y return tf.reduce_sum(tpu.repeat(COUNT, _matmul, [x, y])) tpu_ops = tpu.batch_parallel(flops, [], num_shards=8) session = tf.Session(tpu_cluster) try: print('Warming up...') session.run(tpu.initialize_system()) session.run(tpu_ops) print('Profiling') start = time.time() session.run(tpu_ops) end = time.time() elapsed = end - start print(elapsed, 'TFlops: {:.2f}'.format(1e-12 * 8 * COUNT * 2 * N * N * N / elapsed)) except Exception as e:
def _run_steps_on_dataset(self, fn, iterator, iterations,
                          initial_loop_values=None):
  """Enqueue input and run `fn` for `iterations` steps on the TPU cores.

  Args:
    fn: Step function; called as `fn(ctx, inputs)` where `ctx` is a
      `values.MultiStepContext` and `inputs` is the dequeued, re-packed
      input structure.
    iterator: A `tf.data` iterator whose output shapes are fully defined.
    iterations: Number of step-function invocations per call.
    initial_loop_values: Optional initial values carried through the TPU
      repeat loop. Defaults to an empty list.

  Returns:
    A tuple `(group_op, last_step_outputs, ctx)` of the grouped
    computation/enqueue op, the last step's tensor outputs, and the
    multi-step context.

  Raises:
    ValueError: If any of the iterator's output shapes is not fully defined
      (TPU infeed requires static shapes).
  """
  # Enqueue ops
  shapes = nest.flatten(iterator.output_shapes)
  # Generator (not a list) inside any() short-circuits on the first
  # not-fully-defined shape.
  if any(not s.is_fully_defined() for s in shapes):
    raise ValueError(
        'TPU currently requires fully defined shapes. Either use '
        'set_shape() on the input tensors or use '
        'dataset.apply(map_and_batch(..., drop_remainder=True)).')
  types = nest.flatten(iterator.output_types)

  def enqueue_ops_fn():
    """Enqueue ops for one iteration."""
    control_deps = []
    sharded_inputs = []
    with ops.device(self._host):
      for _ in range(self._num_cores_per_host):
        # Use control dependencies to ensure a deterministic ordering.
        with ops.control_dependencies(control_deps):
          inputs = nest.flatten(iterator.get_next())
          control_deps.extend(inputs)
          sharded_inputs.append(inputs)

    enqueue_ops = []
    for core_id, shard_input in enumerate(sharded_inputs):
      enqueue_ops.append(
          tpu_ops.infeed_enqueue_tuple(
              inputs=shard_input, shapes=shapes, device_ordinal=core_id))
    return enqueue_ops

  def enqueue_ops_loop_body(i):
    # Chain each iteration's enqueue ops behind the counter update so the
    # while loop actually drives the enqueues.
    with ops.control_dependencies(enqueue_ops_fn()):
      return i + 1

  with ops.device(self._host):
    enqueue_ops = control_flow_ops.while_loop(
        lambda i: i < iterations,
        enqueue_ops_loop_body, [constant_op.constant(0)],
        parallel_iterations=1)

  # Dequeue ops
  def dequeue_fn():
    dequeued = tpu.infeed_dequeue_tuple(dtypes=types, shapes=shapes)
    # Re-pack the flat dequeued tensors into the iterator's structure.
    return nest.pack_sequence_as(iterator.output_shapes, dequeued)

  # Wrap `fn` for repeat.
  if initial_loop_values is None:
    initial_loop_values = []
  ctx = values.MultiStepContext(initial_loop_values)

  def run_fn(*args, **kwargs):
    # Loop-carried values are unused; the step's input comes from infeed.
    del args, kwargs
    fn_result = fn(ctx, dequeue_fn())
    if ctx.last_step_outputs is None:
      ctx.last_step_outputs = []
    # Force the step's side effects to run before emitting its outputs.
    with ops.control_dependencies([fn_result]):
      return array_ops.identity(ctx.last_step_outputs)

  # Repeat
  # TODO(sourabhbajaj): The input to while loop should be based on the output
  # type of the step_fn
  def iterate_on_tpu():
    return tpu.repeat(iterations, run_fn, [initial_loop_values])

  # Re-write and distribute computation.
  # TODO(sourabhbajaj): Convert the output to PerDevice variable and
  # implement support for that in reduce.
  last_step_tensor_outputs = tpu.batch_parallel(
      iterate_on_tpu, [], num_shards=self._num_cores_per_host)

  # Take index [0] of last_step_tensor_outputs as we wrapped
  # initial_loop_values in a list in the `repeat` call.
  return (control_flow_ops.group(last_step_tensor_outputs, enqueue_ops),
          last_step_tensor_outputs[0], ctx)
def _call_for_each_tower(self, fn, *args, **kwargs):
  """Run `fn` per step, infeeding any `values.PerIteration` arguments.

  Arguments that are `PerIteration` instances are routed through TPU
  infeed (one slice per iteration per core); all other arguments are
  passed to `fn` unchanged.

  Args:
    fn: Function to run on each core for each iteration.
    *args: Positional arguments for `fn`; may contain `PerIteration`s.
    **kwargs: Keyword arguments for `fn`; may contain `PerIteration`s.
      `run_concurrently` is popped and ignored.

  Returns:
    An op grouping the distributed TPU computation with the enqueue loop.

  Raises:
    ValueError: If any infed input's shape is not fully defined (TPU
      infeed requires static shapes).
  """
  kwargs.pop('run_concurrently', None)

  inputs = {'args': args, 'kwargs': kwargs}
  flat_inputs = nest.flatten(inputs)

  # Marks which flattened inputs are fed via infeed vs. passed through.
  feed_mask = [isinstance(f, values.PerIteration) for f in flat_inputs]

  # A fresh generator on each call; itertools.compress is single-use.
  feeds = lambda: itertools.compress(flat_inputs, feed_mask)
  shapes = [f.get_shape() for f in feeds()]
  # Generator (not a list) inside any() short-circuits on the first
  # not-fully-defined shape.
  if any(not s.is_fully_defined() for s in shapes):
    raise ValueError(
        'TPU currently requires fully defined shapes. Either use '
        'set_shape() on the input tensors or use '
        'dataset.apply(map_and_batch(..., drop_remainder=True)).')
  types = [f.get_dtype() for f in feeds()]

  def infeed_input(i):
    """Get input, split it and then enqueue."""
    iteration_inputs = [f.get(i) for f in feeds()]
    # Transpose from per-feed lists of per-core tensors to per-core lists.
    infeed_inputs = [[inputs_per_core[core_id]
                      for inputs_per_core in iteration_inputs]
                     for core_id in range(self._num_cores_per_host)]

    infeed_ops = []
    # Loop variable renamed so it no longer shadows this function's name.
    for core_id, per_core_input in enumerate(infeed_inputs):
      infeed_ops.append(
          tpu_ops.infeed_enqueue_tuple(
              inputs=per_core_input, shapes=shapes, device_ordinal=core_id))

    with ops.control_dependencies(infeed_ops):
      return i + 1

  with ops.device('/task:0/device:CPU:0'):
    enqueue_ops = control_flow_ops.while_loop(
        lambda i: i < self._iterations_per_step,
        infeed_input, [constant_op.constant(0)],
        parallel_iterations=1)

  def dequeueing_fn(*args, **kwargs):
    """Dequeue input arguments and supply them to `fn`."""
    del args, kwargs
    dequeued = tpu.infeed_dequeue_tuple(dtypes=types, shapes=shapes)
    dequeued = iter(dequeued)

    # Interleave dequeued tensors back into their original positions.
    fn_inputs = []
    for inp, is_feed in zip(flat_inputs, feed_mask):
      if is_feed:
        fn_inputs.append(next(dequeued))
      else:
        fn_inputs.append(inp)

    fn_inputs = nest.pack_sequence_as(inputs, fn_inputs)
    return fn(*fn_inputs['args'], **fn_inputs['kwargs'])

  def iterate_on_tpu():
    return tpu.repeat(self._iterations_per_step, dequeueing_fn, [])

  with one_device_strategy._OneDeviceTowerContext(self):  # pylint: disable=protected-access
    tpu_result = tpu.batch_parallel(
        iterate_on_tpu, [], num_shards=self._num_cores_per_host)

  return control_flow_ops.group(tpu_result, enqueue_ops)
def _call_for_each_tower(self, fn, *args, **kwargs):
  """Run `fn` per step, infeeding any `values.PerIteration` arguments.

  Arguments that are `PerIteration` instances are routed through TPU
  infeed (one slice per iteration per core); all other arguments are
  passed to `fn` unchanged.

  Args:
    fn: Function to run on each core for each iteration.
    *args: Positional arguments for `fn`; may contain `PerIteration`s.
    **kwargs: Keyword arguments for `fn`; may contain `PerIteration`s.
      `run_concurrently` is popped and ignored.

  Returns:
    An op grouping the distributed TPU computation with the enqueue loop.

  Raises:
    ValueError: If any infed input's shape is not fully defined (TPU
      infeed requires static shapes).
  """
  kwargs.pop('run_concurrently', None)

  inputs = {'args': args, 'kwargs': kwargs}
  flat_inputs = nest.flatten(inputs)

  # Marks which flattened inputs are fed via infeed vs. passed through.
  feed_mask = [isinstance(f, values.PerIteration) for f in flat_inputs]

  # A fresh generator on each call; itertools.compress is single-use.
  feeds = lambda: itertools.compress(flat_inputs, feed_mask)
  shapes = [f.get_shape() for f in feeds()]
  # Generator (not a list) inside any() short-circuits on the first
  # not-fully-defined shape.
  if any(not s.is_fully_defined() for s in shapes):
    raise ValueError(
        'TPU currently requires fully defined shapes. Either use '
        'set_shape() on the input tensors or use '
        'dataset.apply(map_and_batch(..., drop_remainder=True)).')
  types = [f.get_dtype() for f in feeds()]

  def infeed_input(i):
    """Get input, split it and then enqueue."""
    iteration_inputs = [f.get(i) for f in feeds()]
    # Transpose from per-feed lists of per-core tensors to per-core lists.
    infeed_inputs = [[inputs_per_core[core_id]
                      for inputs_per_core in iteration_inputs]
                     for core_id in range(self._num_cores_per_host)]

    infeed_ops = []
    # Loop variable renamed so it no longer shadows this function's name.
    for core_id, per_core_input in enumerate(infeed_inputs):
      infeed_ops.append(
          tpu_ops.infeed_enqueue_tuple(
              inputs=per_core_input, shapes=shapes, device_ordinal=core_id))

    with ops.control_dependencies(infeed_ops):
      return i + 1

  with ops.device('/task:0/device:CPU:0'):
    enqueue_ops = control_flow_ops.while_loop(
        lambda i: i < self._iterations_per_step,
        infeed_input, [constant_op.constant(0)],
        parallel_iterations=1)

  def dequeueing_fn(*args, **kwargs):
    """Dequeue input arguments and supply them to `fn`."""
    del args, kwargs
    dequeued = tpu.infeed_dequeue_tuple(dtypes=types, shapes=shapes)
    dequeued = iter(dequeued)

    # Interleave dequeued tensors back into their original positions.
    fn_inputs = []
    for inp, is_feed in zip(flat_inputs, feed_mask):
      if is_feed:
        fn_inputs.append(next(dequeued))
      else:
        fn_inputs.append(inp)

    fn_inputs = nest.pack_sequence_as(inputs, fn_inputs)
    return fn(*fn_inputs['args'], **fn_inputs['kwargs'])

  def iterate_on_tpu():
    return tpu.repeat(self._iterations_per_step, dequeueing_fn, [])

  with one_device_strategy._OneDeviceTowerContext(self):  # pylint: disable=protected-access
    tpu_result = tpu.batch_parallel(
        iterate_on_tpu, [], num_shards=self._num_cores_per_host)

  return control_flow_ops.group(tpu_result, enqueue_ops)