Example No. 1
def model(self):
    hparams = self._hparams
    encoder_layer_stack = layer_stack_from_hparams(hparams, "encoder_")
    decoder_layer_stack = layer_stack_from_hparams(hparams, "decoder_")
    encoder = transformer.Unitransformer(
        layer_stack=encoder_layer_stack,
        d_model=hparams.d_model,
        input_vocab_size=self._inputs_vocab_size,
        output_vocab_size=None,
        autoregressive=False,
        max_length=hparams.max_length,
        name="encoder",
        layout=hparams.layout,
        mesh_shape=hparams.mesh_shape,
    )
    decoder = transformer.Unitransformer(
        layer_stack=decoder_layer_stack,
        d_model=hparams.d_model,
        input_vocab_size=self._targets_vocab_size,
        output_vocab_size=self._targets_vocab_size,
        autoregressive=True,
        max_length=hparams.max_length,
        label_smoothing=hparams.label_smoothing,
        shared_embedding_and_softmax_weights=(
            hparams.shared_embedding_and_softmax_weights),
        z_loss=hparams.z_loss,
        name="decoder",
        layout=hparams.layout,
        mesh_shape=hparams.mesh_shape,
    )
    return transformer.Bitransformer(
        encoder, decoder, shared_embedding=hparams.shared_embedding)
Example No. 2
def model(self):
    hparams = self._hparams
    if hparams.label_smoothing != 0:
        raise NotImplementedError(
            "Label smoothing not implemented in unitransformer."
            "  Do you really want it?")
    if isinstance(hparams.layer_stack, transformer.LayerStack):
        layer_stack = hparams.layer_stack
    else:
        # hparams.layer_stack is a function for creating a LayerStack
        layer_stack = hparams.layer_stack(hparams, "")
    if self.autoregressive:
        input_vocab_size = self._targets_vocab_size
    else:
        input_vocab_size = self._inputs_vocab_size
    return transformer.Unitransformer(
        layer_stack=layer_stack,
        d_model=hparams.d_model,
        input_vocab_size=input_vocab_size,
        output_vocab_size=self._targets_vocab_size,
        autoregressive=self.autoregressive,
        max_length=hparams.max_length,
        shared_embedding_and_softmax_weights=(
            hparams.shared_embedding_and_softmax_weights),
        z_loss=hparams.z_loss,
        layout=hparams.layout,
        mesh_shape=hparams.mesh_shape)
Example No. 3
def build_model(model_type="bitransformer",
                input_vocab_size=gin.REQUIRED,
                output_vocab_size=gin.REQUIRED,
                layout_rules=None,
                mesh_shape=None):
    """Build a transformer model.

  Currently, four types of models are supported:

  "bitransformer": The traditional encoder-decoder architecture from
     "attention is all you need".  Requires a non-text2self dataset.

  "lm": an autoregressive language model (one layer stack).  This is similar
     to the decoder part of a bitransformer, but with no attention over an
     encoder, since there is no encoder.  Requires a text2self dataset,
     with targets, but no inputs.

  "aligned": a non-autoregressive single-stack model (like BERT).  Requires
     a non-text2self dataset with inputs and targets.  The targets are
     aligned with the inputs.

  "bi_teacher_student": a teacher-student model where both the student and
    teacher are bitransformers. Requires a non-text2self dataset.

  Args:
    model_type: a string - "bitransformer", "lm" or "aligned"
    input_vocab_size: an integer
    output_vocab_size: an integer
    layout_rules: optional - an input to mtf.convert_to_layout_rules
    mesh_shape: optional - an input to mtf.convert_to_shape
  Returns:
    a Unitransformer or Bitransformer
  """
    if model_type == "bitransformer":
        return transformer.make_bitransformer(
            input_vocab_size=input_vocab_size,
            output_vocab_size=output_vocab_size,
            mesh_shape=mesh_shape,
            layout=layout_rules)
    elif model_type == "bi_student_teacher":
        return transformer.make_bi_student_teacher(
            input_vocab_size=input_vocab_size,
            output_vocab_size=output_vocab_size,
            mesh_shape=mesh_shape,
            layout=layout_rules)
    elif model_type == "lm" or model_type == "aligned":
        return transformer.Unitransformer(
            autoregressive=model_type == "lm",
            layer_stack=transformer.make_layer_stack(),
            input_vocab_size=input_vocab_size,
            output_vocab_size=output_vocab_size,
            mesh_shape=mesh_shape,
            layout=layout_rules)
    else:
        raise ValueError("unknown model_type")
Example No. 4
def get_dummy_decoder_context(converter,
                              batch=2,
                              d_model=6,
                              length=4,
                              mode="incremental",
                              initial_position=None,
                              state=None,
                              inputs=None):
    """Builds a dummy transformer.Context around an empty model for decoder tests."""
    batch_dim = mtf.Dimension("batch", batch)
    length_dim = mtf.Dimension("length", length)

    # Set up a dummy model
    layer_stack = transformer.LayerStack(layers=[])
    model = transformer.Unitransformer(
        d_model=d_model,
        input_vocab_size=10,  # dummy values
        output_vocab_size=10,  # dummy values
        autoregressive=True,
        max_length=length,
        layer_stack=layer_stack)

    if state is not None:
        state_mtf = converter.convert_np_array_to_mtf_tensor(
            state, dtype=tf.float32, dim_names=["batch", "length", "d_model"])
        states = [state_mtf]
    else:
        states = None

    if initial_position:
        initial_position = mtf.constant(converter.mesh,
                                        initial_position,
                                        shape=mtf.Shape([batch_dim]),
                                        dtype=tf.int32)

    if inputs is not None:
        inputs = converter.convert_np_array_to_mtf_tensor(
            inputs, dim_names=["batch", "length"])

    context = transformer.Context(model=model,
                                  mode=mode,
                                  states=states,
                                  new_states=[],
                                  mesh=converter.mesh,
                                  batch_dims=[batch_dim],
                                  length_dim=length_dim,
                                  variable_dtype=mtf.VariableDType(tf.float32),
                                  sequence_id=1,
                                  inputs=inputs,
                                  initial_position=initial_position)
    return context
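A sketch of how this helper might be exercised in a test. The converter object is assumed to be a test utility exposing .mesh and convert_np_array_to_mtf_tensor, as used above; the array shape and initial position are placeholder values.

import numpy as np

# Hypothetical test usage; `converter` is assumed to expose .mesh and
# convert_np_array_to_mtf_tensor as in the function above.
dummy_state = np.zeros([2, 4, 6], dtype=np.float32)  # [batch, length, d_model]
context = get_dummy_decoder_context(
    converter,
    mode="incremental",
    initial_position=2,   # scalar, broadcast to shape [batch]
    state=dummy_state)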
Example No. 5
def model(self):
    hparams = self._hparams
    if hparams.label_smoothing != 0:
        raise NotImplementedError(
            "Label smoothing not implemented in unitransformer."
            "  Do you really want it?")
    if isinstance(hparams.layer_stack, transformer.LayerStack):
        layer_stack = hparams.layer_stack
    else:
        # hparams.layer_stack is a function for creating a LayerStack
        layer_stack = hparams.layer_stack(hparams)
    if self.autoregressive:
        # In the autoregressive (text2self) case the inputs are the targets,
        # so the input vocabulary is the targets vocabulary.
        input_vocab_size = self._targets_vocab_size
    else:
        input_vocab_size = self._inputs_vocab_size
    return transformer.Unitransformer(
        layer_stack=layer_stack,
        d_model=hparams.d_model,
        input_vocab_size=input_vocab_size,
        output_vocab_size=self._targets_vocab_size,
        autoregressive=self.autoregressive,
        max_length=hparams.max_length)
Example No. 6
def create_dummy_model(mesh,
                       shapes,
                       n_blocks=2,
                       block_param_size_str="2_2",
                       block_repeat_size_str="1_1"):
    """Creates a dummy model and layer stack with 4-dimensional input."""

    assert len(shapes) == 4
    outer_batch_size, batch_size, length, d_model = shapes
    batch_dim = mtf.Dimension("batch", batch_size)
    outer_batch_dim = mtf.Dimension("outer_batch", outer_batch_size)
    length_dim = mtf.Dimension("length", length)
    block_param_size = list(map(int, block_param_size_str.split("_")))
    block_repeat_size = list(map(int, block_repeat_size_str.split("_")))

    sublayers_initial = [
        transformer.sublayer_dropout,
    ]
    sublayers_per_layer = [
        transformer.sublayer_rms_norm,
        transformer.sublayer_call_layer,
        transformer.sublayer_dropout,
        transformer.sublayer_residual,
    ]
    sublayers_final = [
        transformer.sublayer_rms_norm,
        transformer.sublayer_dropout,
    ]
    submodules = [
        transformer_layers.SelfAttention(),
        transformer_layers.DenseReluDense()
    ]

    n_sublayers = np.array(block_param_size).prod()
    layers = submodules * n_sublayers
    layer_stack = funnel_transformer.FunnelTransformerLayerStack(
        layers=layers,
        n_blocks=n_blocks,
        block_param_size=block_param_size,
        block_repeat_size=block_repeat_size,
        sublayers_initial=sublayers_initial,
        sublayers_per_layer=sublayers_per_layer,
        sublayers_final=sublayers_final)

    model = transformer.Unitransformer(input_vocab_size=10,
                                       output_vocab_size=10,
                                       autoregressive=False,
                                       max_length=8,
                                       d_model=d_model,
                                       layer_stack=layer_stack)

    context = transformer.Context(model=model,
                                  mesh=mesh,
                                  batch_dims=[batch_dim, outer_batch_dim],
                                  length_dim=length_dim,
                                  variable_dtype=mtf.VariableDType(tf.float32),
                                  sequence_id=mtf.ones(
                                      mesh, mtf.Shape([length_dim])),
                                  position=mtf.range(
                                      mesh, length_dim, dtype=tf.int32))
    return layer_stack, context
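A minimal sketch of driving create_dummy_model; the mesh construction follows the standard mesh_tensorflow pattern, and the shape values are placeholder assumptions.

import mesh_tensorflow as mtf

# Hypothetical setup; shape values [outer_batch, batch, length, d_model]
# are illustrative placeholders.
graph = mtf.Graph()
mesh = mtf.Mesh(graph, "dummy_mesh")
layer_stack, context = create_dummy_model(
    mesh,
    shapes=[1, 2, 4, 8],
    n_blocks=2,
    block_param_size_str="2_2",
    block_repeat_size_str="1_1")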
Example No. 7
def model(input_vocab_size,
          output_vocab_size,
          text2self,
          num_layers=gin.REQUIRED,
          d_ff=gin.REQUIRED,
          d_kv=gin.REQUIRED,
          d_model=gin.REQUIRED,
          num_heads=gin.REQUIRED,
          dropout=gin.REQUIRED,
          max_length=gin.REQUIRED,
          length=gin.REQUIRED,
          label_smoothing=gin.REQUIRED,
          layout=gin.REQUIRED,
          mesh_shape=gin.REQUIRED):
    """Build a simple Transformer model.

  Args:
    input_vocab_size: an integer
    output_vocab_size: an integer
    text2self: a boolean meaning a language model (True) or encoder/decoder
      (False)
    num_layers: integer, number of transformer layers
    d_ff: integer, size of feed-forward hidden layers
    d_kv: integer, size of attention keys/values
    d_model: integer, size of hidden state
    num_heads: integer, heads per attention layer
    dropout: float, dropout rate
    max_length: maximum sequence length (checkpoints depend on this)
    length: actual sequence length - defaults to max_length
    label_smoothing: label smoothing
    layout: a string
    mesh_shape: a string

  Returns:
    a mtf.Unitransformer or mtf.Bitransformer
  """
    # `length` is accepted only so that it can be injected by Gin; it is
    # unused in this function.
    del length

    def layer_stack(include_encdec_attention):
        """Create a LayerStack.

    TODO(noam): implement a way to configure custom layer stacks using
    hyperparameters. (as in mtf_transformer2 in the tensor2tensor library).
    That functionality should go in transformer/model_builder.py

    Args:
      include_encdec_attention: a boolean

    Returns:
      a transformer.LayerStack
    """
        return model_builder.simple_layer_stack(
            include_encdec_attention=include_encdec_attention,
            num_layers=num_layers,
            d_ff=d_ff,
            d_kv=d_kv,
            num_heads=num_heads,
            dropout_rate=dropout)

    if text2self:
        return transformer.Unitransformer(
            layer_stack=layer_stack(include_encdec_attention=False),
            d_model=d_model,
            input_vocab_size=input_vocab_size,
            output_vocab_size=output_vocab_size,
            autoregressive=True,
            max_length=max_length,
            shared_embedding_and_softmax_weights=True,
            label_smoothing=label_smoothing,
            layout=layout,
            mesh_shape=mesh_shape)
    else:
        return transformer.Bitransformer(
            encoder_layer_stack=layer_stack(include_encdec_attention=False),
            decoder_layer_stack=layer_stack(include_encdec_attention=True),
            encoder_d_model=d_model,
            decoder_d_model=d_model,
            input_vocab_size=input_vocab_size,
            output_vocab_size=output_vocab_size,
            max_length=max_length,
            shared_embedding=False,
            shared_embedding_and_softmax_weights=True,
            label_smoothing=label_smoothing,
            layout=layout,
            mesh_shape=mesh_shape)
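A hedged sketch of calling this builder directly (it is normally configured through gin); all hyperparameter values below are placeholder assumptions.

# Hypothetical direct call; hyperparameter values are placeholders.
lm = model(
    input_vocab_size=32000,
    output_vocab_size=32000,
    text2self=True,       # autoregressive language model -> Unitransformer
    num_layers=6,
    d_ff=2048,
    d_kv=64,
    d_model=512,
    num_heads=8,
    dropout=0.1,
    max_length=256,
    length=256,
    label_smoothing=0.0,
    layout="batch:batch",
    mesh_shape="batch:1")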