Example #1
@property
def pnum_tensor(self):
  if self._pnum_tensor is not None:
    return self._pnum_tensor
  with mtf_utils.outside_all_rewrites():
    tf.logging.info("Create pnum_tensor")
    # One replicated input op feeds each TPU core its own processor number.
    self._pnum_tensor = tpu_ops.tpu_replicated_input(
        list(range(self.size)), name="pnum_constants")
    return self._pnum_tensor
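The pattern above is ordinary lazy caching: the tpu_replicated_input op is built once, on first access, and every later access returns the cached tensor. A framework-free sketch of the same build-once pattern (the list stands in for the real replicated tensor):

class PnumCache(object):
  """Minimal sketch of Example #1's build-once, return-cached pattern."""

  def __init__(self, size):
    self.size = size
    self._pnum_tensor = None

  @property
  def pnum_tensor(self):
    if self._pnum_tensor is not None:
      return self._pnum_tensor
    # Stand-in for the graph construction done under outside_all_rewrites().
    self._pnum_tensor = list(range(self.size))
    return self._pnum_tensor

cache = PnumCache(8)
assert cache.pnum_tensor is cache.pnum_tensor  # built once, then reused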
Example #2
def metric_fn(tf_logits, labels):
  with tf.device("cpu:0"), mtf_utils.outside_all_rewrites():
    eval_metrics = {}
    # Skip metrics that cannot run on TPU; each remaining fn builds its
    # tf.metrics op on the CPU host. (The loop variable is renamed to fn
    # here to avoid shadowing the enclosing metric_fn.)
    for metric_name, fn in six.iteritems(eval_metrics_fns):
      if metric_name.split("/")[-1] not in t2t_model.TPU_METRIC_BLACKLIST:
        eval_metrics[metric_name] = fn(
            tf_logits, None, tf.identity(labels))
    return eval_metrics
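For context, a metric_fn like this is normally handed to TPUEstimator as a (fn, args) tuple; the estimator then runs it on the CPU host, which is why the example pins the computation to cpu:0 and steps outside all TPU rewrites. A sketch under that assumption (the tensor arguments are placeholders for whatever the model exports, and accuracy is an arbitrary stand-in metric):

import tensorflow as tf
from tensorflow.contrib.tpu.python.tpu import tpu_estimator

def eval_spec(tf_logits, labels, tf_loss):
  # TPUEstimator calls metric_fn(*args) on the host, not on the TPU cores.
  def metric_fn(tf_logits, labels):
    return {"accuracy": tf.metrics.accuracy(labels, tf.argmax(tf_logits, -1))}

  return tpu_estimator.TPUEstimatorSpec(
      mode=tf.estimator.ModeKeys.EVAL,
      loss=tf_loss,
      eval_metrics=(metric_fn, [tf_logits, labels]))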
Example #3
def __init__(self, variable, mesh_impl):
  """Create a LaidOutVariable.

  Args:
    variable: a Variable (Operation)
    mesh_impl: a MeshImpl
  """
  self._variable = variable
  self._mesh_impl = mesh_impl
  shape = variable.outputs[0].shape
  dtype = variable.outputs[0].dtype
  slice_shape = mesh_impl.slice_shape(shape)
  base_name = variable.name
  slices = []
  for pnum in xrange(mesh_impl.size):
    slice_var_name = base_name + "_slice_%d" % pnum
    tpu_device = mesh_impl.device_assignment.tpu_device(replica=pnum)
    # The initializer is unimportant, since the slice variables will be
    # overwritten.  zeros_initializer() is here to avoid the default
    # initialization which adds lots of useless operations to the TF graph.
    with ops.device(tpu_device):
      slices.append(
          tf.get_variable(slice_var_name,
                          slice_shape,
                          dtype=dtype,
                          collections=[],
                          initializer=tf.zeros_initializer()))
  self._laid_out_tensor = mesh_impl.LaidOutTensor(
      [tpu_variables.ReplicatedVariable(base_name, slices)])
  with tf.device(variable.master.device), mtf_utils.outside_all_rewrites():
    self._copy_master_to_slices = self.assign_to_slices(
        mesh_impl.make_slices(variable.master, shape),
        assign_to_tensor_list=slices)
    self._copy_slices_to_master = tf.assign(
        variable.master,
        mesh_impl.combine_slices(slices,
                                 shape,
                                 device=variable.master.device))
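The two assignments at the end give the variable a round trip: make_slices splits the master variable into one shard per processor, and combine_slices reassembles them. A numpy sketch of that split/combine contract (the helper names mirror, but are not, the mesh_impl API, and splitting along axis 0 is just one possible layout):

import numpy as np

def make_slices(master, num_slices):
  # One shard per processor, split along the first dimension.
  return np.split(master, num_slices, axis=0)

def combine_slices(slices):
  return np.concatenate(slices, axis=0)

master = np.arange(8.0).reshape(8, 1)
slices = make_slices(master, 4)
assert np.array_equal(combine_slices(slices), master)  # lossless round trip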
Example #4
    def estimator_model_fn(cls,
                           hparams,
                           features,
                           labels,
                           mode,
                           config=None,
                           params=None,
                           decode_hparams=None,
                           use_tpu=False,
                           xla_compile=False):
        del xla_compile
        hparams = copy.deepcopy(hparams)
        hparams.use_tpu = use_tpu
        # merge decode_hparams into hparams if present
        if mode == tf.estimator.ModeKeys.PREDICT and decode_hparams is not None:
            for k, v in six.iteritems(decode_hparams.values()):
                if hasattr(hparams, k) and getattr(hparams, k) != v:
                    tf.logging.warning(
                        "Overriding hparams.%s with %s from decode_hparams" %
                        (k, v))
                setattr(hparams, k, v)

        # Instantiate model
        data_parallelism = None
        if not use_tpu and config:
            data_parallelism = config.data_parallelism
        model = cls(hparams,
                    mode,
                    data_parallelism=data_parallelism,
                    decode_hparams=decode_hparams)

        global_step = tf.train.get_global_step()
        graph = mtf.Graph()
        mesh = mtf.Mesh(graph, "my_mesh")

        mesh_shape = mtf.convert_to_shape(hparams.mesh_shape)
        layout_rules = mtf.convert_to_layout_rules(hparams.layout)
        if use_tpu:
            mesh_devices = [""] * mesh_shape.size
            mesh_impl = simd_mesh_impl.SimdMeshImpl(
                mesh_shape, layout_rules, mesh_devices,
                params["context"].device_assignment)
        else:
            if len(data_parallelism.ps_devices) == 1:
                mesh_devices = [""] * mesh_shape.size
            else:
                assert len(data_parallelism.ps_devices) == mesh_shape.size
                mesh_devices = data_parallelism.ps_devices
            mesh_impl = placement_mesh_impl.PlacementMeshImpl(
                mesh_shape, layout_rules, mesh_devices)

        # PREDICT mode
        if mode == tf.estimator.ModeKeys.PREDICT:
            return model.estimator_spec_predict(features, mesh, mesh_impl,
                                                use_tpu)

        logits, loss = model.mtf_model_fn(features, mesh)
        if use_tpu and logits is not None:
            logits = mtf.anonymize(logits)

        # TRAIN mode
        if mode == tf.estimator.ModeKeys.TRAIN:
            var_grads = mtf.gradients(
                [loss], [v.outputs[0] for v in graph.trainable_variables])
            lr = learning_rate.learning_rate_schedule(hparams)
            mtf_lr = mtf.import_tf_tensor(
                mesh, tf.convert_to_tensor(lr, dtype=tf.float32),
                mtf.Shape([]))
            optimizer = mtf_optimize.make_optimizer(hparams, mtf_lr)
            update_ops = []
            for grad, var in zip(var_grads, graph.trainable_variables):
                update_ops.extend(optimizer.apply_grad(grad, var))

        lowering = mtf.Lowering(graph, {mesh: mesh_impl})

        tf_loss = lowering.export_to_tf_tensor(loss)
        tf_loss = tf.to_float(tf_loss)
        if logits and mode != tf.estimator.ModeKeys.TRAIN:
            tf_logits = lowering.export_to_tf_tensor(logits)

        if mode == tf.estimator.ModeKeys.TRAIN:
            tf_update_ops = [
                lowering.lowered_operation(op) for op in update_ops
            ]
            tf_update_ops.append(tf.assign_add(global_step, 1))
            # tf.logging.info("tf_update_ops: {}".format(tf_update_ops))
            train_op = tf.group(tf_update_ops)

        with mtf_utils.outside_all_rewrites():
            # Copy master variables to slices. Must be called first.
            restore_hook = mtf.MtfRestoreHook(lowering)
            saver = tf.train.Saver(tf.global_variables(),
                                   sharded=True,
                                   max_to_keep=10,
                                   keep_checkpoint_every_n_hours=2,
                                   defer_build=False,
                                   save_relative_paths=True)
            tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
            saver_listener = mtf.MtfCheckpointSaverListener(lowering)
            saver_hook = tf.train.CheckpointSaverHook(
                hparams.model_dir,
                save_steps=1000,
                saver=saver,
                listeners=[saver_listener])

        # EVAL mode
        if mode == tf.estimator.ModeKeys.EVAL:
            tf_logits = lowering.export_to_tf_tensor(logits)
            return model.estimator_spec_eval(features, tf_logits, labels,
                                             tf_loss, restore_hook, use_tpu)

        if use_tpu:
            _remove_summaries()
            return tpu_estimator.TPUEstimatorSpec(
                mode=tf.estimator.ModeKeys.TRAIN,
                loss=tf_loss,
                train_op=train_op,
                training_hooks=[restore_hook, saver_hook])
        else:
            return tf.estimator.EstimatorSpec(
                tf.estimator.ModeKeys.TRAIN,
                loss=tf_loss,
                train_op=train_op,
                training_chief_hooks=[restore_hook, saver_hook])
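A model_fn built this way is not called directly; it is typically handed to TPUEstimator, which invokes it per mode. A sketch of that wiring, assuming TF 1.x; my_model_fn, my_input_fn, and the concrete values here are hypothetical:

import tensorflow as tf

run_config = tf.contrib.tpu.RunConfig(
    model_dir="/tmp/mtf_model",
    tpu_config=tf.contrib.tpu.TPUConfig(iterations_per_loop=100))

estimator = tf.contrib.tpu.TPUEstimator(
    model_fn=my_model_fn,  # e.g. the estimator_model_fn above
    config=run_config,
    train_batch_size=64,
    eval_batch_size=64,
    use_tpu=True)

estimator.train(input_fn=my_input_fn, max_steps=10000)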
Example #5
def model_fn(features, labels, mode, params):
  """A model is called by TpuEstimator."""
  del labels
  global_step = tf.train.get_global_step()
  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, 'my_mesh')
  mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
  mesh_devices = [''] * mesh_shape.size
  mesh_impl = SimdMeshImpl(
      mesh_shape, mtf.convert_to_layout_rules(FLAGS.layout),
      mesh_devices, params['context'].device_assignment)
  with mtf_utils.outside_all_rewrites():
    logits, loss = toy_model(features, mesh)

  # TRAIN mode
  if mode == tf.estimator.ModeKeys.TRAIN:
    var_grads = mtf.gradients([loss],
                              [v.outputs[0] for v in graph.trainable_variables])
    optimizer = mtf_optimize.AdafactorOptimizer()
    update_ops = []
    for grad, var in zip(var_grads, graph.trainable_variables):
      update_ops.extend(optimizer.apply_grad(grad, var))
  else:
    # for now, we can only export fully-replicated tensors.
    fully_replicated_logits = mtf.anonymize(logits)

  lowering = mtf.Lowering(graph, {mesh: mesh_impl})

  tf_loss = lowering.export_to_tf_tensor(loss)

  if mode == tf.estimator.ModeKeys.TRAIN:
    tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
    tf_update_ops.append(tf.assign_add(global_step, 1))
    tf.logging.info('tf_update_ops: {}'.format(tf_update_ops))
    train_op = tf.group(tf_update_ops)
  else:
    tf_logits = lowering.export_to_tf_tensor(fully_replicated_logits)

  with mtf_utils.outside_all_rewrites():
    # Copy master variables to slices. Must be called first.
    restore_hook = mtf.MtfRestoreHook(lowering)
    if mode == tf.estimator.ModeKeys.TRAIN:
      saver = tf.train.Saver(
          tf.global_variables(),
          sharded=True,
          max_to_keep=10,
          keep_checkpoint_every_n_hours=2,
          defer_build=False,
          save_relative_paths=True)
      tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
      saver_listener = mtf.MtfCheckpointSaverListener(lowering)
      saver_hook = tf.train.CheckpointSaverHook(
          FLAGS.model_dir,
          save_steps=1000,
          saver=saver,
          listeners=[saver_listener])

      return tpu_estimator.TPUEstimatorSpec(
          tf.estimator.ModeKeys.TRAIN,
          loss=tf_loss,
          train_op=train_op,
          training_hooks=[restore_hook, saver_hook])
    elif mode == tf.estimator.ModeKeys.EVAL:

      def metric_fn(tf_logits):
        mean_logits = tf.metrics.mean(tf_logits)
        return {'mean_logits': mean_logits}

      eval_metrics = (metric_fn, [tf_logits])

      return tpu_estimator.TPUEstimatorSpec(
          tf.estimator.ModeKeys.EVAL,
          evaluation_hooks=[restore_hook],
          loss=tf_loss,
          eval_metrics=eval_metrics)
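The FLAGS.mesh_shape and FLAGS.layout strings this model_fn parses describe how tensors are sharded over the TPU cores. A small sketch of plausible values, mirroring a toy-model setup (an 8-core mesh with the batch dimension split across all cores; the concrete strings are illustrative):

import mesh_tensorflow as mtf

mesh_shape = mtf.convert_to_shape("all:8")               # one mesh axis, 8 cores
layout_rules = mtf.convert_to_layout_rules("batch:all")  # shard batch over it
print(mesh_shape.size)  # 8: one slice variable per processor, as in Example #3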