Example #1
  def run_steps_on_dataset(self, fn, iterator, iterations):
    # Enqueue ops
    shapes = nest.flatten(iterator.output_shapes)
    if any([not s.is_fully_defined() for s in shapes]):
      raise ValueError(
          'TPU currently requires fully defined shapes. Either use '
          'set_shape() on the input tensors or use '
          'dataset.apply(map_and_batch(..., drop_remainder=True)).')
    types = nest.flatten(iterator.output_types)

    def enqueue_ops_fn():
      """Enqueue ops for one iteration."""
      control_deps = []
      sharded_inputs = []
      with ops.device(self._host):
        for _ in range(self._num_cores_per_host):
          # Use control dependencies to ensure a deterministic ordering.
          with ops.control_dependencies(control_deps):
            inputs = nest.flatten(iterator.get_next())
            control_deps.extend(inputs)
            sharded_inputs.append(inputs)

      enqueue_ops = []
      for core_id, shard_input in enumerate(sharded_inputs):
        enqueue_ops.append(
            tpu_ops.infeed_enqueue_tuple(
                inputs=shard_input, shapes=shapes, device_ordinal=core_id))
      return enqueue_ops

    def enqueue_ops_loop_body(i):
      with ops.control_dependencies(enqueue_ops_fn()):
        return i + 1

    with ops.device(self._host):
      enqueue_ops = control_flow_ops.while_loop(
          lambda i: i < iterations,
          enqueue_ops_loop_body,
          [constant_op.constant(0)],
          parallel_iterations=1)

    # Dequeue ops
    def dequeue_fn():
      dequeued = tpu.infeed_dequeue_tuple(dtypes=types, shapes=shapes)
      return nest.pack_sequence_as(iterator.output_shapes, dequeued)

    # Wrap `fn` for repeat.
    run_fn = lambda: fn(dequeue_fn())

    # Repeat
    def iterate_on_tpu():
      return tpu.repeat(iterations, run_fn, [])

    # Re-write and distribute computation.
    tpu_result = tpu.batch_parallel(
        iterate_on_tpu, [], num_shards=self._num_cores_per_host)

    return control_flow_ops.group(tpu_result, enqueue_ops)
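
The method above follows a fixed pattern: enqueue one infeed tuple per core on the host inside a while_loop, dequeue the same tuple inside the TPU computation, wrap the step with tpu.repeat, and distribute it with tpu.batch_parallel. Below is a minimal standalone sketch of that pattern, assuming TF 1.x with tf.contrib.tpu, two TPU cores, and a toy dataset; NUM_CORES, ITERATIONS, and step_fn are illustrative names, not part of the code above.

import tensorflow as tf
from tensorflow.contrib import tpu
from tensorflow.contrib.tpu.python.ops import tpu_ops

NUM_CORES = 2    # assumed number of TPU cores attached to the host
ITERATIONS = 10  # steps enqueued and executed per session.run

# Toy host-side input pipeline: one [NUM_CORES, 8] float32 batch per step.
dataset = tf.data.Dataset.from_tensors(tf.ones([NUM_CORES, 8])).repeat()
iterator = dataset.make_one_shot_iterator()

def enqueue_ops_fn():
  # Split the host batch and enqueue one [1, 8] shard per core.
  shards = tf.split(iterator.get_next(), NUM_CORES)
  return [
      tpu_ops.infeed_enqueue_tuple(
          inputs=[shards[core]], shapes=[[1, 8]], device_ordinal=core)
      for core in range(NUM_CORES)
  ]

def enqueue_ops_loop_body(i):
  with tf.control_dependencies(enqueue_ops_fn()):
    return i + 1

with tf.device('/task:0/device:CPU:0'):  # assumed host device
  enqueue_ops = tf.while_loop(
      lambda i: i < ITERATIONS, enqueue_ops_loop_body,
      [tf.constant(0)], parallel_iterations=1)

def step_fn(total):
  # Dequeue the shard enqueued above and fold it into a loop-carried value.
  x, = tpu.infeed_dequeue_tuple(dtypes=[tf.float32], shapes=[[1, 8]])
  return total + tf.reduce_sum(x, axis=1)

def iterate_on_tpu():
  return tpu.repeat(ITERATIONS, step_fn, [tf.zeros([1])])

per_shard_totals = tpu.batch_parallel(
    iterate_on_tpu, [], num_shards=NUM_CORES)
step_op = tf.group(enqueue_ops, *per_shard_totals)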
Example #2
  def _call_for_each_tower(self, fn, *args, **kwargs):
    kwargs.pop('run_concurrently', None)

    # TODO(isaprykin): Give an API for many iterations per step.
    iterations = 1

    # TODO(isaprykin): Do not hard code shapes and input format :)
    # TODO(isaprykin): Detect the number of TPU cores automatically.

    def dequeueing_fn(*args, **kwargs):
      del args, kwargs
      x, = tpu.infeed_dequeue_tuple(dtypes=[dtypes.float32], shapes=[[1, 1, 1]])
      return fn(x)

    iterator = args[0]

    def infeed_input(i):
      """Get input, split it and then enqueue."""
      batches = iterator.get_next()
      batches = array_ops.split(batches, 2)

      infeeds = [
          tpu_ops.infeed_enqueue_tuple(
              inputs=[batches[j]], shapes=[[1, 1, 1]], device_ordinal=j)
          for j in range(2)
      ]

      with ops.control_dependencies(infeeds):
        return i + 1

    with ops.device('/task:0/device:CPU:0'):
      enqueue_ops = control_flow_ops.while_loop(
          lambda i: i < iterations,
          infeed_input, [constant_op.constant(0)],
          parallel_iterations=1)

    def iterate_on_tpu():
      return tpu.repeat(iterations, dequeueing_fn, [])

    with one_device_strategy._OneDeviceTowerContext(self):  # pylint: disable=protected-access
      tpu_result = tpu.batch_parallel(iterate_on_tpu, [], num_shards=2)

    return control_flow_ops.group(tpu_result, enqueue_ops)
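
This version hard-codes two cores and float32 inputs of shape [1, 1, 1], so the iterator passed in as args[0] must yield [2, 1, 1] batches (split into one [1, 1, 1] tensor per core), and because the tpu.repeat call carries no loop values, fn is expected to return an Operation rather than a Tensor. A minimal sketch of compatible inputs, assuming TF 1.x; the dataset and fn below are illustrative only.

import tensorflow as tf

# Each host step yields a [2, 1, 1] float32 batch; array_ops.split(batches, 2)
# in the method above turns it into one [1, 1, 1] tensor per core.
dataset = tf.data.Dataset.from_tensors(
    tf.zeros([2, 1, 1], dtype=tf.float32)).repeat()
iterator = dataset.make_one_shot_iterator()

def fn(x):
  # Stand-in per-core computation on the dequeued [1, 1, 1] tensor; a real
  # step function would build the model here and return its train op.
  loss = tf.reduce_sum(tf.square(x))
  return tf.group(loss)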
Example #3
import time

import tensorflow as tf
from tensorflow.contrib import tpu

# `tpu_cluster` (the address of the TPU master, e.g. obtained from a
# TPUClusterResolver) is assumed to be defined elsewhere.

N = 4096
COUNT = 100


def flops():
    x = tf.random_uniform([N, N])
    y = tf.random_uniform([N, N])

    def _matmul(x, y):
        return tf.tensordot(x, y, axes=[[1], [0]]), y

    return tf.reduce_sum(tpu.repeat(COUNT, _matmul, [x, y]))


tpu_ops = tpu.batch_parallel(flops, [], num_shards=8)

session = tf.Session(tpu_cluster)

try:
    print('Warming up...')
    session.run(tpu.initialize_system())
    session.run(tpu_ops)
    print('Profiling')
    start = time.time()
    session.run(tpu_ops)
    end = time.time()
    elapsed = end - start
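    # Each [N, N] x [N, N] matmul costs about 2 * N**3 FLOPs; COUNT matmuls run
    # on each of the 8 shards, hence the 8 * COUNT * 2 * N**3 in the numerator.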
    print(elapsed,
          'TFlops: {:.2f}'.format(1e-12 * 8 * COUNT * 2 * N * N * N / elapsed))
except Exception as e:
    # Minimal handler (assumed): re-raise the error from the TPU run.
    raise
Example #4
  def _run_steps_on_dataset(self, fn, iterator, iterations,
                            initial_loop_values=None):
    # Enqueue ops
    shapes = nest.flatten(iterator.output_shapes)
    if any([not s.is_fully_defined() for s in shapes]):
      raise ValueError(
          'TPU currently requires fully defined shapes. Either use '
          'set_shape() on the input tensors or use '
          'dataset.apply(map_and_batch(..., drop_remainder=True)).')
    types = nest.flatten(iterator.output_types)

    def enqueue_ops_fn():
      """Enqueue ops for one iteration."""
      control_deps = []
      sharded_inputs = []
      with ops.device(self._host):
        for _ in range(self._num_cores_per_host):
          # Use control dependencies to ensure a deterministic ordering.
          with ops.control_dependencies(control_deps):
            inputs = nest.flatten(iterator.get_next())
            control_deps.extend(inputs)
            sharded_inputs.append(inputs)

      enqueue_ops = []
      for core_id, shard_input in enumerate(sharded_inputs):
        enqueue_ops.append(
            tpu_ops.infeed_enqueue_tuple(
                inputs=shard_input, shapes=shapes, device_ordinal=core_id))
      return enqueue_ops

    def enqueue_ops_loop_body(i):
      with ops.control_dependencies(enqueue_ops_fn()):
        return i + 1

    with ops.device(self._host):
      enqueue_ops = control_flow_ops.while_loop(
          lambda i: i < iterations,
          enqueue_ops_loop_body,
          [constant_op.constant(0)],
          parallel_iterations=1)

    # Dequeue ops
    def dequeue_fn():
      dequeued = tpu.infeed_dequeue_tuple(dtypes=types, shapes=shapes)
      return nest.pack_sequence_as(iterator.output_shapes, dequeued)

    # Wrap `fn` for repeat.
    if initial_loop_values is None:
      initial_loop_values = []
    ctx = values.MultiStepContext(initial_loop_values)
    def run_fn(*args, **kwargs):
      del args, kwargs
      fn_result = fn(ctx, dequeue_fn())
      if ctx.last_step_outputs is None:
        ctx.last_step_outputs = []
      with ops.control_dependencies([fn_result]):
        return array_ops.identity(ctx.last_step_outputs)

    # Repeat
    # TODO(sourabhbajaj): The input to while loop should be based on the output
    # type of the step_fn
    def iterate_on_tpu():
      return tpu.repeat(iterations, run_fn, [initial_loop_values])

    # Re-write and distribute computation.
    # TODO(sourabhbajaj): Convert the output to PerDevice variable and
    # implement support for that in reduce.
    last_step_tensor_outputs = tpu.batch_parallel(
        iterate_on_tpu, [], num_shards=self._num_cores_per_host)

    # Take index [0] of last_step_tensor_outputs as we wrapped
    # initial_loop_values in a list in the `repeat` call.
    return (control_flow_ops.group(last_step_tensor_outputs, enqueue_ops),
            last_step_tensor_outputs[0], ctx)
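
Here the step function is called as fn(ctx, dequeue_fn()): it receives the MultiStepContext built above followed by the dequeued inputs, returns something the loop can depend on, and can expose per-step tensors by assigning ctx.last_step_outputs, which are surfaced through last_step_tensor_outputs above. A minimal sketch of a compatible step function, assuming the iterator yields a single float32 tensor; the loss computation is illustrative only.

import tensorflow as tf

def step_fn(ctx, inputs):
  # `inputs` has the structure produced by dequeue_fn(), i.e. the iterator's
  # output structure; a single tensor is assumed here.
  loss = tf.reduce_sum(tf.square(inputs))
  # Tensors assigned to `last_step_outputs` come back as the second element
  # of the (group_op, last_step_tensor_outputs[0], ctx) tuple returned above.
  ctx.last_step_outputs = [loss]
  return tf.group(loss)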
Example #5
    def _call_for_each_tower(self, fn, *args, **kwargs):
        kwargs.pop('run_concurrently', None)

        inputs = {'args': args, 'kwargs': kwargs}
        flat_inputs = nest.flatten(inputs)

        feed_mask = [isinstance(f, values.PerIteration) for f in flat_inputs]

        feeds = lambda: itertools.compress(flat_inputs, feed_mask)
        shapes = [f.get_shape() for f in feeds()]
        if any([not s.is_fully_defined() for s in shapes]):
            raise ValueError(
                'TPU currently requires fully defined shapes. Either use '
                'set_shape() on the input tensors or use '
                'dataset.apply(map_and_batch(..., drop_remainder=True)).')
        types = [f.get_dtype() for f in feeds()]

        def infeed_input(i):
            """Get input, split it and then enqueue."""
            iteration_inputs = [f.get(i) for f in feeds()]

            infeed_inputs = [[
                inputs_per_core[core_id]
                for inputs_per_core in iteration_inputs
            ] for core_id in range(self._num_cores_per_host)]

            infeed_ops = []
            for core_id, infeed_input in enumerate(infeed_inputs):
                infeed_ops.append(
                    tpu_ops.infeed_enqueue_tuple(inputs=infeed_input,
                                                 shapes=shapes,
                                                 device_ordinal=core_id))

            with ops.control_dependencies(infeed_ops):
                return i + 1

        with ops.device('/task:0/device:CPU:0'):
            enqueue_ops = control_flow_ops.while_loop(
                lambda i: i < self._iterations_per_step,
                infeed_input, [constant_op.constant(0)],
                parallel_iterations=1)

        def dequeueing_fn(*args, **kwargs):
            """Dequeue input arguments and supply them to `fn`."""
            del args, kwargs
            dequeued = tpu.infeed_dequeue_tuple(dtypes=types, shapes=shapes)
            dequeued = iter(dequeued)

            fn_inputs = []
            for inp, is_feed in zip(flat_inputs, feed_mask):
                if is_feed:
                    fn_inputs.append(next(dequeued))
                else:
                    fn_inputs.append(inp)

            fn_inputs = nest.pack_sequence_as(inputs, fn_inputs)
            return fn(*fn_inputs['args'], **fn_inputs['kwargs'])

        def iterate_on_tpu():
            return tpu.repeat(self._iterations_per_step, dequeueing_fn, [])

        with one_device_strategy._OneDeviceTowerContext(self):  # pylint: disable=protected-access
            tpu_result = tpu.batch_parallel(
                iterate_on_tpu, [], num_shards=self._num_cores_per_host)

        return control_flow_ops.group(tpu_result, enqueue_ops)
Example #6
  def _call_for_each_tower(self, fn, *args, **kwargs):
    kwargs.pop('run_concurrently', None)

    inputs = {'args': args, 'kwargs': kwargs}
    flat_inputs = nest.flatten(inputs)

    feed_mask = [isinstance(f, values.PerIteration) for f in flat_inputs]

    feeds = lambda: itertools.compress(flat_inputs, feed_mask)
    shapes = [f.get_shape() for f in feeds()]
    if any([not s.is_fully_defined() for s in shapes]):
      raise ValueError(
          'TPU currently requires fully defined shapes. Either use '
          'set_shape() on the input tensors or use '
          'dataset.apply(map_and_batch(..., drop_remainder=True)).')
    types = [f.get_dtype() for f in feeds()]

    def infeed_input(i):
      """Get input, split it and then enqueue."""
      iteration_inputs = [f.get(i) for f in feeds()]
      infeed_inputs = [[inputs_per_core[core_id]
                        for inputs_per_core in iteration_inputs]
                       for core_id in range(self._num_cores_per_host)]

      infeed_ops = []
      for core_id, infeed_input in enumerate(infeed_inputs):
        infeed_ops.append(
            tpu_ops.infeed_enqueue_tuple(
                inputs=infeed_input, shapes=shapes, device_ordinal=core_id))

      with ops.control_dependencies(infeed_ops):
        return i + 1

    with ops.device('/task:0/device:CPU:0'):
      enqueue_ops = control_flow_ops.while_loop(
          lambda i: i < self._iterations_per_step,
          infeed_input, [constant_op.constant(0)],
          parallel_iterations=1)

    def dequeueing_fn(*args, **kwargs):
      """Dequeue input arguments and supply them to `fn`."""
      del args, kwargs
      dequeued = tpu.infeed_dequeue_tuple(dtypes=types, shapes=shapes)
      dequeued = iter(dequeued)

      fn_inputs = []
      for inp, is_feed in zip(flat_inputs, feed_mask):
        if is_feed:
          fn_inputs.append(next(dequeued))
        else:
          fn_inputs.append(inp)

      fn_inputs = nest.pack_sequence_as(inputs, fn_inputs)
      return fn(*fn_inputs['args'], **fn_inputs['kwargs'])

    def iterate_on_tpu():
      return tpu.repeat(self._iterations_per_step, dequeueing_fn, [])

    with one_device_strategy._OneDeviceTowerContext(self):  # pylint: disable=protected-access
      tpu_result = tpu.batch_parallel(
          iterate_on_tpu, [], num_shards=self._num_cores_per_host)

    return control_flow_ops.group(tpu_result, enqueue_ops)
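
The trickiest part of the last two examples is the bookkeeping around nest.flatten, itertools.compress and nest.pack_sequence_as: per-iteration feeds are pulled out of the flattened args/kwargs structure, replaced with dequeued tensors, and packed back before calling fn. The round trip can be exercised on its own, as in the sketch below (plain Python values stand in for the dequeued tensors; the Feed class is an illustrative stand-in for values.PerIteration).

import itertools

from tensorflow.python.util import nest

class Feed(object):
  """Illustrative stand-in for values.PerIteration."""
  pass

inputs = {'args': (Feed(), 3.0), 'kwargs': {'training': True}}
flat_inputs = nest.flatten(inputs)

# Mark which flattened positions are per-iteration feeds.
feed_mask = [isinstance(f, Feed) for f in flat_inputs]
feeds = list(itertools.compress(flat_inputs, feed_mask))

# Pretend these came out of infeed_dequeue_tuple, one per feed.
dequeued = iter(['dequeued_%d' % i for i in range(len(feeds))])

# Substitute dequeued values for the feeds, pass everything else through.
fn_inputs = [next(dequeued) if is_feed else inp
             for inp, is_feed in zip(flat_inputs, feed_mask)]

# Restore the original {'args': ..., 'kwargs': ...} structure.
fn_inputs = nest.pack_sequence_as(inputs, fn_inputs)
assert fn_inputs['args'][0] == 'dequeued_0'
assert fn_inputs['kwargs'] == {'training': True}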