Exemplo n.º 1
0
 def _OutfeedEnqueue(self, per_example_tensors):
     """Build the op that enqueues per-example tensors onto the outfeed.

     Args:
       per_example_tensors: a value accepted by `py_utils.NestedMap`; a falsy
         value means there is nothing to enqueue.

     Returns:
       `tf.no_op()` when `per_example_tensors` is falsy, otherwise the result
       of `tpu_ops.outfeed_enqueue_tuple` on the flattened tensors.
     """
     if per_example_tensors:
         nested = py_utils.NestedMap(per_example_tensors)
         # Use TPU core 0 when self.spmd is truthy, else an empty device string.
         if self.spmd:
             placement = tpu.core(0)
         else:
             placement = ''
         with tf.device(placement):
             flat_tensors = nested.Flatten()
             return tpu_ops.outfeed_enqueue_tuple(flat_tensors)
     return tf.no_op()
Exemplo n.º 2
0
 def _DecodeStep():
     """Decode step to be compiled for TPU.

     Dequeues one input batch, runs the task's `Decode` on it, stores the
     result on `self.metrics_nm` as a `NestedMap`, and enqueues the flattened
     metrics on the outfeed.

     Returns:
       A one-element list holding the outfeed enqueue op.
     """
     batch = self._task.input.TpuDequeueBatch()
     self.metrics_nm = py_utils.NestedMap(self._task.Decode(batch))
     # Use TPU core 0 when self.spmd is truthy, else an empty device string.
     if self.spmd:
         placement = tpu.core(0)
     else:
         placement = ''
     with tf.device(placement):
         flat_metrics = self.metrics_nm.Flatten()
         return [tpu_ops.outfeed_enqueue_tuple(flat_metrics)]
Exemplo n.º 3
0
            def decode_fn(*infeed_batch):
                """Configure the outfeed from the infeed batch and enqueue it.

                Per the original note: length 6 is passed when there is no
                tgt_mask (e.g. decoding) and length 7 is passed when there is
                a tgt_mask (e.g. fprop).

                Returns:
                  A one-element list holding the outfeed enqueue op.
                """
                self.outfeed = self._config_outfeed(xformer, infeed_batch)
                flat_outfeed = tf.nest.flatten(self.outfeed)

                # The enqueue op is created under the device for TPU core 0.
                with tf.device(tf.tpu.core(0)):
                    enqueue = tpu_ops.outfeed_enqueue_tuple(flat_outfeed)

                return [enqueue]
Exemplo n.º 4
0
 def eval_loop(self):
     """Repeat `self.eval_step` and send the final predictions to the outfeed.

     Returns:
       A `tf.no_op()` created under a control dependency on the outfeed
       enqueue of the predictions.
     """
     batch_per_replica = self.eval_batch_size // self.num_replicas
     tf.get_variable_scope().reuse_variables()
     # Initial accumulator handed to the loop alongside a zero constant.
     initial_predictions = tf.zeros(
         [self.eval_steps, batch_per_replica, 2])
     loop_inputs = [tf.constant(0), initial_predictions]
     _, predictions = training_loop.repeat(
         int(self.eval_steps), self.eval_step, loop_inputs)
     enqueue_op = tpu_ops.outfeed_enqueue_tuple([predictions])
     with tf.control_dependencies([enqueue_op]):
         return tf.no_op()
Exemplo n.º 5
0
 def _DecodeStep():
     """Decode step to be compiled for TPU.

     Inside an opportunistic variable-reuse scope, instantiates the model
     variables, dequeues one input batch and runs the task's `Decode`. The
     result is stored on `self.metrics_nm` as a `NestedMap` and its flattened
     values are enqueued on the outfeed.

     Returns:
       A one-element list holding the outfeed enqueue op.
     """
     with py_utils.OpportunisticVariableReuseScope(True):
         self._model.InstantiateVariables()
         batch = self._task.input.TpuDequeueBatch()
         decode_out = self._task.Decode(batch)
     self.metrics_nm = py_utils.NestedMap(decode_out)
     # Use TPU core 0 when self.spmd is truthy, else an empty device string.
     placement = tpu.core(0) if self.spmd else ''
     with tf.device(placement):
         flat_metrics = self.metrics_nm.Flatten()
         return [tpu_ops.outfeed_enqueue_tuple(flat_metrics)]
        def tpu_eval_step():
            """Generate the TPU graph for one prediction step.

            Dequeues from the first eval infeed queue, packs the values into
            `self.eval_feature_structure`, runs `model_fn` in PREDICT mode on
            the "features" entry, and records each prediction's name and
            tensor on `self.outfeed_names` / `self.outfeed_tensors`.

            Returns:
              A `tf.no_op()` created under a control dependency on the
              outfeed enqueue of `self.outfeed_tensors`.
            """
            dequeued = self.eval_infeed_queue[0].generate_dequeue_op(
                tpu_device=0)
            packed = data_nest.pack_sequence_as(
                self.eval_feature_structure, dequeued)
            spec = model_fn(
                packed["features"], None,
                tf.estimator.ModeKeys.PREDICT, params)
            for name, tensor in six.iteritems(spec.predictions):
                self.outfeed_names.append(name)
                self.outfeed_tensors.append(tensor)

            host_device = utils.device_for_tpu_core(self._get_host(0))
            with tf.device(host_device):
                enqueue_op = tpu_ops.outfeed_enqueue_tuple(
                    self.outfeed_tensors)
            with tf.control_dependencies([enqueue_op]):
                return tf.no_op()
Exemplo n.º 7
0
 def eval_step(self):
     """Run one evaluation step and enqueue the predictions on the outfeed.

     Side effects: sets `self.predict_output` and appends one empty list to
     `self.dequeue_ops` per item yielded by iterating `self.predict_output`.

     Returns:
       A one-element list holding the outfeed enqueue op.
     """
     dequeued = self.infeed_op[False].generate_dequeue_op()
     structure = self.feature_structure[False]
     templates = tf.nest.flatten(structure)
     # Slice each dequeued tensor from the origin to its template's shape.
     sliced = []
     for tensor, template in zip(dequeued, templates):
         origin = [0] * tensor.shape.ndims
         sliced.append(tf.slice(tensor, origin, template.shape))
     packed = tf.nest.pack_sequence_as(structure, sliced)
     if self.eval_has_labels:
         features, labels = packed
     else:
         features, labels = packed, None
     self.maybe_add_embedding_features(features, False)
     _, self.predict_output = self.model_fn(features, labels, False)
     self.dequeue_ops.extend([] for _ in self.predict_output)
     host_device = device_for_tpu_core(self.get_host(0))
     with tf.device(host_device):
         flat_output = tf.nest.flatten(self.predict_output)
         return [tpu_ops.outfeed_enqueue_tuple(flat_output)]
Exemplo n.º 8
0
 def _OutfeedEnqueue(self, per_example_tensors):
     """Build the op that enqueues per-example tensors onto the outfeed.

     Args:
       per_example_tensors: a value accepted by `py_utils.NestedMap`; a falsy
         value means there is nothing to enqueue.

     Returns:
       `tf.no_op()` when `per_example_tensors` is falsy, otherwise the result
       of `tpu_ops.outfeed_enqueue_tuple` on the flattened tensors.
     """
     if per_example_tensors:
         flat_tensors = py_utils.NestedMap(per_example_tensors).Flatten()
         return tpu_ops.outfeed_enqueue_tuple(flat_tensors)
     return tf.no_op()
Exemplo n.º 9
0
  def _model_fn(input_fea, input_lab):
    """Creates a model, add summary, modes (train or eval), and hooks.

    Relies on names from the enclosing scope (e.g. `train_or_eval`, `FLAGS`,
    `mesh_context`, `captured_hooks`, `captured_output_dtypes_shapes`).

    Args:
      input_fea: features; wrapped in a list if not already one.
      input_lab: labels; wrapped in a list if not already one.

    Returns:
      In 'train' mode, a `tf.group` of the loss and the update ops.
      In eval mode with `FLAGS.use_tpu`, the outfeed enqueue op over the
      predictions, loss and global step; otherwise the list of those tensors.
    """

    # input_fea and input_lab should be a list (laid_out_tensors).
    if not isinstance(input_fea, list):
      input_fea = [input_fea]
    if not isinstance(input_lab, list):
      input_lab = [input_lab]

    def _add_summary(lowering, train_or_eval, tf_loss, scalars, global_step):
      """Add all summaries and return the loss tensor that depends on them."""
      # Export every value that is not already a tf.Tensor through the
      # lowering as float32. Iterate over a snapshot since entries are
      # reassigned in place.
      for name, value in list(scalars.items()):
        if not isinstance(value, tf.Tensor):
          scalars[name] = tf.cast(
              lowering.export_to_tf_tensor(value), tf.float32)

      def _host_loss_summary(global_step, tf_loss, **scalars):
        """Add summary.scalar in host side."""
        gs = tf.cast(global_step, tf.int64)
        sum_loss = contrib_summary.scalar(
            '{}_loss'.format(train_or_eval), tf_loss, step=gs)
        sum_ops = [sum_loss.op]
        # dict.items() works on both Python 2 and 3; dict.iteritems() does
        # not exist on Python 3.
        for description, tf_metric in scalars.items():
          sum_metric = contrib_summary.scalar(
              '{}_{}'.format(train_or_eval, description), tf_metric, step=gs)
          sum_ops.append(sum_metric)
        with tf.control_dependencies(sum_ops):
          return tf.identity(tf_loss)

      if FLAGS.use_tpu:
        # Cast the global step to tf.int32, since
        # outside_compilation does not support tf.int64.
        tf_loss = tpu.outside_compilation(
            _host_loss_summary,
            tf.cast(global_step, tf.int32),
            tf_loss,
            **scalars)
      else:
        tf_loss = _host_loss_summary(
            tf.cast(global_step, tf.int32),
            tf_loss,
            **scalars)

      return tf_loss

    global_step = tf.train.get_or_create_global_step()
    graph, mesh, mesh_impl = mesh_context.create_graph_mesh_and_mesh_impl()

    with mtf.utils.outside_all_rewrites():
      # Do not tpu_rewrite this part. Inside this unet, if you use TensorFlow
      # instead of Mesh-TensorFlow, it will cause host to tpu send/rec.
      preds, loss, scalars, bn_update_ops = (
          unet.unet_with_spatial_partition(
              mesh, mesh_impl, train_or_eval, input_fea, input_lab))

    if train_or_eval == 'train':
      var_grads = mtf.gradients(
          [loss], [v.outputs[0] for v in graph.trainable_variables])

      # Learning rate: FLAGS.lr * lr_drop_rate ** floor(step / lr_drop_steps).
      lr = FLAGS.lr * tf.pow(
          FLAGS.lr_drop_rate,
          tf.floor(tf.cast(global_step, tf.float32) / FLAGS.lr_drop_steps))
      scalars['learning_rate'] = lr

      optimizer = mtf.optimize.AdafactorOptimizer(learning_rate=lr)
      update_ops = optimizer.apply_grads(var_grads, graph.trainable_variables)

      # This is where the actual tf graph got built.
      lowering = mtf.Lowering(graph, {mesh: mesh_impl})

      tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
      tf_update_ops.append(tf.assign_add(global_step, 1))
      tf_update_ops.extend(
          [lowering.lowered_operation(op) for op in bn_update_ops])

    else:  # train_or_eval == 'eval':
      preds = [mtf.anonymize(pred) for pred in preds]

      # This is where the actual tf graph got built.
      lowering = mtf.Lowering(graph, {mesh: mesh_impl})

      tf_preds = [tf.cast(
          lowering.export_to_tf_tensor(pred), tf.float32) for pred in preds]

    tf_loss = tf.cast(lowering.export_to_tf_tensor(loss), tf.float32)
    if FLAGS.write_summary:
      tf_loss = _add_summary(
          lowering, train_or_eval, tf_loss, scalars, global_step)
    master_to_slice_hook = mtf.MtfRestoreHook(lowering)

    if train_or_eval == 'train':
      with mtf.utils.outside_all_rewrites():
        saver = tf.train.Saver(tf.global_variables(),
                               save_relative_paths=True)
        tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
        saver_listener = mtf.MtfCheckpointSaverListener(lowering)
        slice_to_master_hook = tf.train.CheckpointSaverHook(
            FLAGS.checkpoint_dir,
            save_steps=FLAGS.save_checkpoints_steps,
            saver=saver, listeners=[saver_listener])
        captured_hooks.capture([master_to_slice_hook, slice_to_master_hook])
        return tf.group([tf_loss] + tf_update_ops)

    else:  # train_or_eval == 'eval':
      if FLAGS.use_tpu:
        tf_preds.extend([tf_loss, global_step])
        # Record dtypes and shapes of everything sent through the outfeed.
        tf_preds_dtypes = [tf_pred.dtype for tf_pred in tf_preds]
        tf_preds_shapes = [tf_pred.shape for tf_pred in tf_preds]
        captured_hooks.capture([master_to_slice_hook, None])
        captured_output_dtypes_shapes.capture(
            [tf_preds_dtypes, tf_preds_shapes])
        return tpu_ops.outfeed_enqueue_tuple(tf_preds)

      else:
        tf_preds.extend([tf_loss, global_step])
        captured_hooks.capture([master_to_slice_hook, None])
        return tf_preds
Exemplo n.º 10
0
    def _model_fn(input_fea, input_lab):
        """Creates a model, add summary, modes (train or eval), and hooks.

        Relies on names from the enclosing scope (e.g. `train_or_eval`,
        `FLAGS`, `num_hosts`, `cpu_devices`, `d_assignment`, `captured_hooks`,
        `captured_output_dtypes_shapes`).

        Args:
          input_fea: features passed to `unet.unet_with_spatial_partition`.
          input_lab: labels passed to `unet.unet_with_spatial_partition`.

        Returns:
          In 'train' mode, the grouped update ops. Otherwise the outfeed
          enqueue op over the predictions, loss and global step.
        """
        def _add_summary(lowering, train_or_eval, tf_loss, scalars,
                         global_step):
            """Add all summaries and return the loss that depends on them."""
            # Export every value that is not already a tf.Tensor through the
            # lowering as float32. Iterate over a snapshot since entries are
            # reassigned in place.
            for name, value in list(scalars.items()):
                if not isinstance(value, tf.Tensor):
                    scalars[name] = tf.cast(
                        lowering.export_to_tf_tensor(value), tf.float32)

            def _host_loss_summary(global_step, tf_loss, **scalars):
                """Add summary.scalar in host side."""
                gs = tf.cast(global_step, tf.int64)
                sum_loss = tf.contrib.summary.scalar(
                    '{}_loss'.format(train_or_eval), tf_loss, step=gs)
                sum_ops = [sum_loss.op]
                # dict.items() works on both Python 2 and 3; dict.iteritems()
                # does not exist on Python 3.
                for description, tf_metric in scalars.items():
                    sum_metric = tf.contrib.summary.scalar(
                        '{}_{}'.format(train_or_eval, description),
                        tf_metric,
                        step=gs)
                    sum_ops.append(sum_metric)
                with tf.control_dependencies(sum_ops):
                    return tf.identity(tf_loss)

            # Cast the global step to tf.int32, since
            # outside_compilation does not support tf.int64.
            tf_loss = tpu.outside_compilation(_host_loss_summary,
                                              tf.cast(global_step, tf.int32),
                                              tf_loss, **scalars)

            return tf_loss

        global_step = tf.train.get_or_create_global_step()
        graph = mtf.Graph()

        # Worker 0 caches all the TPU binaries.
        replica_cache_size = 300 * 1024 * 1024  # 300M per replica.
        worker0_mem = replica_cache_size * 8 * num_hosts
        devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)

        tf.logging.info('cpu_devices: {}, devices_mem: {}'.format(
            cpu_devices, devices_memory_usage))
        var_placer = mtf.utils.BalancedVariablePlacer(cpu_devices,
                                                      devices_memory_usage)

        mesh = mtf.Mesh(graph, 'my_mesh', var_placer)

        mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
        layout_rules = unet.get_layout()
        mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(mesh_shape, layout_rules,
                                                    None, d_assignment)

        with mtf.utils.outside_all_rewrites():  # Do not tpu_rewrite this part.
            preds, loss, scalars, bn_update_ops = (
                unet.unet_with_spatial_partition(mesh, train_or_eval,
                                                 input_fea, input_lab))

        if train_or_eval == 'train':
            var_grads = mtf.gradients(
                [loss], [v.outputs[0] for v in graph.trainable_variables])

            # Learning rate:
            # FLAGS.lr * lr_drop_rate ** floor(step / lr_drop_steps).
            lr = FLAGS.lr * tf.pow(
                FLAGS.lr_drop_rate,
                tf.floor(
                    tf.cast(global_step, tf.float32) / FLAGS.lr_drop_steps))
            scalars['learning_rate'] = lr

            optimizer = mtf.optimize.AdafactorOptimizer(learning_rate=lr)
            update_ops = optimizer.apply_grads(var_grads,
                                               graph.trainable_variables)

            # This is where the actual tf graph got built.
            lowering = mtf.Lowering(graph, {mesh: mesh_impl})

            tf_update_ops = [
                lowering.lowered_operation(op) for op in update_ops
            ]
            tf_update_ops.append(tf.assign_add(global_step, 1))
            tf_update_ops.extend(
                [lowering.lowered_operation(op) for op in bn_update_ops])
            tf_update_ops_group = tf.group(tf_update_ops)

        else:  # train_or_eval == 'eval':
            preds = [mtf.anonymize(pred) for pred in preds]

            # This is where the actual tf graph got built.
            lowering = mtf.Lowering(graph, {mesh: mesh_impl})

            tf_preds = [
                tf.cast(lowering.export_to_tf_tensor(pred), tf.float32)
                for pred in preds
            ]

        tf_loss = tf.cast(lowering.export_to_tf_tensor(loss), tf.float32)
        if FLAGS.write_summary:
            tf_loss = _add_summary(lowering, train_or_eval, tf_loss, scalars,
                                   global_step)
        master_to_slice_hook = mtf.MtfRestoreHook(lowering)

        if train_or_eval == 'train':
            with mtf.utils.outside_all_rewrites():
                saver = tf.train.Saver(tf.global_variables(),
                                       save_relative_paths=True)
                tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
                saver_listener = mtf.MtfCheckpointSaverListener(lowering)
                slice_to_master_hook = tf.train.CheckpointSaverHook(
                    FLAGS.checkpoint_dir,
                    save_steps=FLAGS.save_checkpoints_steps,
                    saver=saver,
                    listeners=[saver_listener])
                captured_hooks.capture(
                    [master_to_slice_hook, slice_to_master_hook])
                return tf_update_ops_group

        else:  # train_or_eval == 'eval':
            tf_preds.extend([tf_loss, global_step])
            # Record dtypes and shapes of everything sent through the outfeed.
            tf_preds_dtypes = [tf_pred.dtype for tf_pred in tf_preds]
            tf_preds_shapes = [tf_pred.shape for tf_pred in tf_preds]
            captured_hooks.capture([master_to_slice_hook, None])
            captured_output_dtypes_shapes.capture(
                [tf_preds_dtypes, tf_preds_shapes])
            return tpu_ops.outfeed_enqueue_tuple(tf_preds)