Example #1
def model(self):
    """Builds an encoder-decoder transformer.Bitransformer from hparams."""
    hparams = self._hparams
    encoder_layer_stack = layer_stack_from_hparams(hparams, "encoder_")
    decoder_layer_stack = layer_stack_from_hparams(hparams, "decoder_")
    encoder = transformer.Unitransformer(
        layer_stack=encoder_layer_stack,
        d_model=hparams.d_model,
        input_vocab_size=self._inputs_vocab_size,
        output_vocab_size=None,
        autoregressive=False,
        max_length=hparams.max_length,
        name="encoder",
        layout=hparams.layout,
        mesh_shape=hparams.mesh_shape,
    )
    decoder = transformer.Unitransformer(
        layer_stack=decoder_layer_stack,
        d_model=hparams.d_model,
        input_vocab_size=self._targets_vocab_size,
        output_vocab_size=self._targets_vocab_size,
        autoregressive=True,
        max_length=hparams.max_length,
        label_smoothing=hparams.label_smoothing,
        shared_embedding_and_softmax_weights=(
            hparams.shared_embedding_and_softmax_weights),
        z_loss=hparams.z_loss,
        name="decoder",
        layout=hparams.layout,
        mesh_shape=hparams.mesh_shape,
    )
    return transformer.Bitransformer(
        encoder, decoder, shared_embedding=hparams.shared_embedding)
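The Bitransformer returned above is typically driven through call_simple, exactly as Example #4 does further down. A minimal usage sketch, assuming m is an instance of the class defining model() above, and that inputs and targets have already been imported onto the mesh as mtf.Tensors (all hypothetical names here):

bitransformer = m.model()  # the Bitransformer built above
logits, loss = bitransformer.call_simple(
    inputs=inputs,    # mtf.Tensor of input token ids, assumed prepared
    targets=targets,  # mtf.Tensor of target token ids, assumed prepared
    compute_loss=True,
    mode=tf.estimator.ModeKeys.TRAIN,
    variable_dtype=mtf.VariableDType(
        master_dtype=tf.float32,
        slice_dtype=tf.float32,
        activation_dtype=tf.float32))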
Example #2
def model(self):
    """Builds a transformer.Bitransformer.

    The layer-stack hparams may hold either a ready transformer.LayerStack
    or a factory callable taking (hparams, prefix).
    """
    hparams = self._hparams
    if isinstance(hparams.encoder_layer_stack, transformer.LayerStack):
        encoder_layer_stack = hparams.encoder_layer_stack
    else:
        encoder_layer_stack = hparams.encoder_layer_stack(
            hparams, "encoder_")
    if isinstance(hparams.decoder_layer_stack, transformer.LayerStack):
        decoder_layer_stack = hparams.decoder_layer_stack
    else:
        decoder_layer_stack = hparams.decoder_layer_stack(
            hparams, "decoder_")
    return transformer.Bitransformer(
        encoder_layer_stack=encoder_layer_stack,
        decoder_layer_stack=decoder_layer_stack,
        encoder_d_model=hparams.d_model,
        decoder_d_model=hparams.d_model,
        input_vocab_size=self._inputs_vocab_size,
        output_vocab_size=self._targets_vocab_size,
        max_length=hparams.max_length,
        shared_embedding=hparams.shared_embedding,
        shared_embedding_and_softmax_weights=(
            hparams.shared_embedding_and_softmax_weights),
        label_smoothing=hparams.label_smoothing,
        z_loss=hparams.z_loss,
        layout=hparams.layout,
        mesh_shape=hparams.mesh_shape)
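The isinstance dispatch above lets each hyperparameter hold either a ready-made transformer.LayerStack or a factory callable taking (hparams, prefix). A hedged configuration sketch; layer_stack_from_hparams is the factory used in Example #1, while my_prebuilt_stack is a hypothetical LayerStack instance:

# Option 1: a factory, later invoked as hparams.encoder_layer_stack(hparams, "encoder_")
hparams.encoder_layer_stack = layer_stack_from_hparams
# Option 2: a concrete transformer.LayerStack, used as-is
hparams.decoder_layer_stack = my_prebuilt_stack  # hypothetical instance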
Example #3
def model(self):
    """Builds a transformer.Bitransformer.

    Same dispatch as Example #2, but a factory callable here takes
    hparams alone.
    """
    hparams = self._hparams
    if isinstance(hparams.encoder_layer_stack, transformer.LayerStack):
        encoder_layer_stack = hparams.encoder_layer_stack
    else:
        encoder_layer_stack = hparams.encoder_layer_stack(hparams)
    if isinstance(hparams.decoder_layer_stack, transformer.LayerStack):
        decoder_layer_stack = hparams.decoder_layer_stack
    else:
        decoder_layer_stack = hparams.decoder_layer_stack(hparams)
    return transformer.Bitransformer(
        encoder_layer_stack=encoder_layer_stack,
        decoder_layer_stack=decoder_layer_stack,
        encoder_d_model=hparams.d_model,
        decoder_d_model=hparams.d_model,
        input_vocab_size=self._inputs_vocab_size,
        output_vocab_size=self._targets_vocab_size,
        max_length=hparams.max_length,
        shared_embedding=hparams.shared_embedding,
        label_smoothing=hparams.label_smoothing)
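Example #3 expects any factory to take hparams alone, so encoder/decoder-specific choices must live inside the factory itself. A sketch of such a factory, reusing model_builder.simple_layer_stack from Example #5; the hparams field names here are assumptions:

def my_layer_stack(hparams):
    # Hypothetical one-argument factory matching Example #3's convention.
    return model_builder.simple_layer_stack(
        include_encdec_attention=False,
        num_layers=hparams.num_layers,
        d_ff=hparams.d_ff,
        d_kv=hparams.d_kv,
        num_heads=hparams.num_heads,
        dropout_rate=hparams.dropout)

hparams.encoder_layer_stack = my_layer_stack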
Example #4
def my_model_fn(features,
                labels,
                mode,
                params=None,
                config=None):
  """Estimator model function.

  Args:
    features: input features dictionary
    labels: ignored
    mode: a tf.estimator.ModeKeys
    params: something
    config: something
  Returns:
    something
  """
  del labels, config
  use_tpu = FLAGS.tpu
  global_step = tf.train.get_global_step()

  mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
  layout_rules = mtf.convert_to_layout_rules(FLAGS.layout)
  if use_tpu:
    ctx = params["context"]
    num_hosts = ctx.num_hosts
    host_placement_fn = ctx.tpu_host_placement_function
    device_list = [host_placement_fn(host_id=t) for t in range(num_hosts)]
    # TODO(ylc): Better estimation of replica cache size?
    replica_cache_size = 300 * 1000000  # 300M per replica
    # Worker 0 caches all the TPU binaries.
    worker0_mem = replica_cache_size * ctx.num_replicas
    devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)
    var_placer = mtf.utils.BalancedVariablePlacer(device_list,
                                                  devices_memory_usage)
    mesh_devices = [""] * mesh_shape.size
    mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
        mesh_shape, layout_rules, mesh_devices, ctx.device_assignment)
  else:
    var_placer = None
    mesh_devices = [""] * mesh_shape.size
    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        mesh_shape, layout_rules, mesh_devices)

  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "my_mesh", var_placer)

  model = transformer.Bitransformer(
      encoder_layer_stack=layer_stack(include_encdec_attention=False),
      decoder_layer_stack=layer_stack(include_encdec_attention=True),
      encoder_d_model=FLAGS.d_model,
      decoder_d_model=FLAGS.d_model,
      input_vocab_size=transformer_dataset.padded_vocab_size(
          transformer_dataset.inputs_vocab_size(FLAGS.dataset)),
      output_vocab_size=transformer_dataset.padded_vocab_size(
          transformer_dataset.targets_vocab_size(FLAGS.dataset)),
      max_length=FLAGS.max_length,
      shared_embedding=False,
      shared_embedding_and_softmax_weights=True,
      label_smoothing=FLAGS.label_smoothing,
      layout=FLAGS.layout,
      mesh_shape=FLAGS.mesh_shape)

  inputs = import_feature(features, mesh, "inputs")

  # Data types used for variables and activations.
  # See the comments on the corresponding FLAGS definitions.
  master_dtype = tf.as_dtype(FLAGS.master_dtype)
  if FLAGS.slice_dtype:
    slice_dtype = tf.as_dtype(FLAGS.slice_dtype)
  elif not FLAGS.tpu or FLAGS.mode == "train":
    slice_dtype = tf.float32
  else:
    slice_dtype = tf.bfloat16
  if FLAGS.activation_dtype:
    activation_dtype = tf.as_dtype(FLAGS.activation_dtype)
  else:
    activation_dtype = tf.bfloat16 if FLAGS.tpu else tf.float32
  variable_dtype = mtf.VariableDType(master_dtype=master_dtype,
                                     slice_dtype=slice_dtype,
                                     activation_dtype=activation_dtype)

  # PREDICT mode
  if mode == tf.estimator.ModeKeys.PREDICT:
    mtf_samples = model.decode(
        inputs,
        variable_dtype=variable_dtype,
        beam_size=FLAGS.beam_size,
        alpha=FLAGS.alpha,
        temperature=FLAGS.temperature)
    mtf_samples = mtf.anonymize(mtf_samples)
    lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=FLAGS.autostack)
    outputs = lowering.export_to_tf_tensor(mtf_samples)
    predictions = {
        "outputs": outputs
    }
    return tpu_estimator.TPUEstimatorSpec(
        mode=tf.estimator.ModeKeys.PREDICT,
        predictions=predictions,
        prediction_hooks=[mtf.MtfRestoreHook(lowering)])

  targets = import_feature(features, mesh, "targets")
  anon_targets = mtf.anonymize(targets)
  logits, loss = model.call_simple(
      inputs=inputs,
      targets=targets,
      compute_loss=True,
      mode=mode,
      variable_dtype=variable_dtype,
      encoder_sequence_id=import_feature(features, mesh, "inputs_segmentation"),
      decoder_sequence_id=import_feature(
          features, mesh, "targets_segmentation"),
      encoder_position=import_feature(features, mesh, "inputs_position"),
      decoder_position=import_feature(features, mesh, "targets_position")
  )

  if use_tpu and logits is not None:
    logits = mtf.anonymize(logits)

  # TRAIN mode
  if mode == tf.estimator.ModeKeys.TRAIN:
    var_grads = mtf.gradients(
        [loss], [v.outputs[0] for v in graph.trainable_variables])
    optimizer = mtf.optimize.AdafactorOptimizer()
    update_ops = optimizer.apply_grads(var_grads, graph.trainable_variables)

  lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=FLAGS.autostack)

  tf_loss = lowering.export_to_tf_tensor(loss)
  tf_loss = tf.to_float(tf_loss)
  if not use_tpu:
    tf_loss = tf.Print(
        tf_loss, [tf_loss, tf.train.get_global_step()], "step, tf_loss")
  if logits is not None and mode != tf.estimator.ModeKeys.TRAIN:
    tf_logits = lowering.export_to_tf_tensor(logits)

  if mode == tf.estimator.ModeKeys.TRAIN:
    tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
    tf_update_ops.append(tf.assign_add(global_step, 1))
    train_op = tf.group(tf_update_ops)

  with mtf.utils.outside_all_rewrites():
    # Copy master variables to slices. Must be called first.
    restore_hook = mtf.MtfRestoreHook(lowering)
    saver = tf.train.Saver(
        tf.global_variables(),
        sharded=True,
        max_to_keep=10,
        keep_checkpoint_every_n_hours=2,
        defer_build=False,
        save_relative_paths=True)
    tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
    saver_listener = mtf.MtfCheckpointSaverListener(lowering)
    saver_hook = tf.train.CheckpointSaverHook(
        FLAGS.model_dir,
        save_steps=1000,
        saver=saver,
        listeners=[saver_listener])

    if mode == tf.estimator.ModeKeys.TRAIN:
      if use_tpu:
        return tpu_estimator.TPUEstimatorSpec(
            mode=tf.estimator.ModeKeys.TRAIN,
            loss=tf_loss,
            train_op=train_op,
            training_hooks=[restore_hook, saver_hook])
      else:
        return tf.estimator.EstimatorSpec(
            tf.estimator.ModeKeys.TRAIN, loss=tf_loss, train_op=train_op,
            training_chief_hooks=[restore_hook, saver_hook])
    elif mode == tf.estimator.ModeKeys.EVAL:
      def padded_neg_log_perplexity(logits, labels):
        weights = tf.to_float(tf.not_equal(labels, 0))
        xent = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=labels, logits=logits)
        return {"neg_log_perplexity": tf.metrics.mean(-xent, weights)}
      labels = lowering.export_to_tf_tensor(anon_targets)
      eval_metrics = (padded_neg_log_perplexity, [tf_logits, labels])
      return tpu_estimator.TPUEstimatorSpec(
          tf.estimator.ModeKeys.EVAL,
          evaluation_hooks=[restore_hook],
          loss=tf_loss,
          eval_metrics=eval_metrics)
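my_model_fn is written to be handed to a TPUEstimator. A minimal wiring sketch using the TF1 tf.contrib.tpu API; the batch-size and step flags and my_input_fn are hypothetical:

from tensorflow.contrib import tpu as contrib_tpu

run_config = contrib_tpu.RunConfig(
    model_dir=FLAGS.model_dir,
    tpu_config=contrib_tpu.TPUConfig(iterations_per_loop=100))
estimator = contrib_tpu.TPUEstimator(
    model_fn=my_model_fn,
    config=run_config,
    use_tpu=FLAGS.tpu,
    train_batch_size=FLAGS.batch_size,  # hypothetical flag
    eval_batch_size=FLAGS.batch_size)
estimator.train(input_fn=my_input_fn, max_steps=FLAGS.train_steps)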
Example #5
def model(input_vocab_size,
          output_vocab_size,
          text2self,
          num_layers=gin.REQUIRED,
          d_ff=gin.REQUIRED,
          d_kv=gin.REQUIRED,
          d_model=gin.REQUIRED,
          num_heads=gin.REQUIRED,
          dropout=gin.REQUIRED,
          max_length=gin.REQUIRED,
          length=gin.REQUIRED,
          label_smoothing=gin.REQUIRED,
          layout=gin.REQUIRED,
          mesh_shape=gin.REQUIRED):
    """Build a simple Transformer model.

  Args:
    input_vocab_size: an integer
    output_vocab_size: an integer
    text2self: a boolean meaning a language model (True) or encoder/decoder
      (False)
    num_layers: integer, number of transformer layers
    d_ff: integer, size of feed-forward hidden layers
    d_kv: integer, size of attention keys/values
    d_model: integer, size of hidden state
    num_heads: integer, heads per attention layer
    dropout: float, dropout rate
    max_length: maximum sequence length (checkpoints depend on this)
    length: actual sequence length - defaults to max_length
    label_smoothing: label smoothing
    layout: a string
    mesh_shape: a string

  Returns:
    a mtf.Unitransformer or mtf.Bitransformer
  """
    # Needed for Gin injection.
    del length

    def layer_stack(include_encdec_attention):
        """Create a LayerStack.

    TODO(noam): implement a way to configure custom layer stacks using
    hyperparameters. (as in mtf_transformer2 in the tensor2tensor library).
    That functionality should go in transformer/model_builder.py

    Args:
      include_encdec_attention: a boolean

    Returns:
      a transformer.LayerStack
    """
        return model_builder.simple_layer_stack(
            include_encdec_attention=include_encdec_attention,
            num_layers=num_layers,
            d_ff=d_ff,
            d_kv=d_kv,
            num_heads=num_heads,
            dropout_rate=dropout)

    if text2self:
        return transformer.Unitransformer(
            layer_stack=layer_stack(include_encdec_attention=False),
            d_model=d_model,
            input_vocab_size=input_vocab_size,
            output_vocab_size=output_vocab_size,
            autoregressive=True,
            max_length=max_length,
            shared_embedding_and_softmax_weights=True,
            label_smoothing=label_smoothing,
            layout=layout,
            mesh_shape=mesh_shape)
    else:
        return transformer.Bitransformer(
            encoder_layer_stack=layer_stack(include_encdec_attention=False),
            decoder_layer_stack=layer_stack(include_encdec_attention=True),
            encoder_d_model=d_model,
            decoder_d_model=d_model,
            input_vocab_size=input_vocab_size,
            output_vocab_size=output_vocab_size,
            max_length=max_length,
            shared_embedding=False,
            shared_embedding_and_softmax_weights=True,
            label_smoothing=label_smoothing,
            layout=layout,
            mesh_shape=mesh_shape)
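Because every architectural argument defaults to gin.REQUIRED, this function is meant to be configured through Gin (assuming it is registered with @gin.configurable under the name "model"). An illustrative binding sketch; the concrete values and the layout string are assumptions, not library defaults:

import gin

gin.parse_config("""
    model.num_layers = 6
    model.d_ff = 2048
    model.d_kv = 128
    model.d_model = 512
    model.num_heads = 8
    model.dropout = 0.1
    model.max_length = 256
    model.length = 256
    model.label_smoothing = 0.1
    model.layout = "batch:batch,d_ff:model,heads:model,vocab:model"
    model.mesh_shape = "batch:8"
""")
# The remaining arguments are supplied at call time.
m = model(input_vocab_size=32768, output_vocab_size=32768, text2self=False)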