Example #1
def c2r3d(cfield, dims, norm=None, dtype=tf.float32, name=None):
    """
  Converts a complex Fourier domain field to a real field

  Parameters:
  -----------
  cfield: tensor (batch_size, nc, nc, nc)
    Complex 3D field in the Fourier domain

  norm: float
    Normalization factor

  dtype: tf.dtype
    Type of output tensor

  Return:
  -------
  rfield: tensor (batch_size, nc, nc, nc)
    Real valued field
  """
    x_dim, y_dim, z_dim = cfield.shape[-3:]
    if norm is None:
        norm = mtf.constant(cfield.mesh, x_dim.size * y_dim.size * z_dim.size)
    rfield = mtf.cast(mesh_ops.ifft3d(cfield, dims), dtype) * norm
    return rfield
Example #2
def positional_encoding(x):
    seq_len_dim, model_dim = x.shape[1:]
    seq_len_size = seq_len_dim.size
    model_size = model_dim.size

    # mtf.constant is documented as creating a tensor with a constant scalar
    # value, but as long as the tensor is fully replicated, initializing it
    # with an array of values works as well.
    assert (not seq_len_dim.name.startswith('axis'))
    assert (not model_dim.name.startswith('axis'))

    # Values for positional encoder
    pos = np.arange(seq_len_size).reshape(-1, 1)
    val = np.power(10000, (2 * np.arange(model_size)) / model_size,
                   dtype=float)
    pos_enc_values = pos / val
    np.sin(pos_enc_values[:, ::2],
           out=pos_enc_values[:, ::2],
           dtype=np.float32)
    np.cos(pos_enc_values[:, 1::2],
           out=pos_enc_values[:, 1::2],
           dtype=np.float32)

    # positional encoder
    pos_enc = mtf.constant(x.mesh,
                           pos_enc_values,
                           shape=mtf.Shape([seq_len_dim, model_dim]),
                           dtype=tf.float32)
    return (x * math.sqrt(model_size)) + pos_enc
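
The same table can be built end-to-end without the in-place ufunc calls. A standalone NumPy sketch (the sizes 4 and 8 are made-up, not taken from the example above):

import numpy as np

seq_len, model_size = 4, 8                      # illustrative sizes
pos = np.arange(seq_len).reshape(-1, 1)         # (seq_len, 1)
val = np.power(10000.0, (2 * np.arange(model_size)) / model_size)
table = pos / val                               # (seq_len, model_size)
table[:, ::2] = np.sin(table[:, ::2])           # even channels carry sin
table[:, 1::2] = np.cos(table[:, 1::2])         # odd channels carry cos
print(table.shape)                              # (4, 8)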
Example #3
def r2c3d(rfield, k_dims, norm=None, dtype=tf.complex64):
    """
  Converts a real field to its complex Fourier Transform

  Parameters:
  -----------
  rfield: tensor (batch_size, nc, nc, nc)
    Input 3D real field

  norm: float
    Normalization factor

  dtype: tf.dtype
    Type of output tensor

  Return:
  -------
  cfield: tensor (batch_size, nc, nc, nc)
    Complex field
  """
    x_dim, y_dim, z_dim = rfield.shape[-3:]
    if norm is None:
        norm = mtf.constant(rfield.mesh, x_dim.size * y_dim.size * z_dim.size)
    cfield = mesh_ops.fft3d(mtf.cast(rfield / norm, dtype), k_dims)
    return cfield
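
Examples 1 and 3 use a symmetric normalization: r2c3d divides by nc**3 before the forward transform and c2r3d multiplies by nc**3 after the inverse, so a round trip recovers the original field. A plain-NumPy sketch of that convention (illustration only; the actual mesh_ops transforms may normalize differently):

import numpy as np

nc = 8
rfield = np.random.randn(1, nc, nc, nc)
norm = nc ** 3

# r2c3d analogue: normalize, then forward FFT over the three spatial axes.
cfield = np.fft.fftn(rfield / norm, axes=(-3, -2, -1))
# c2r3d analogue: inverse FFT, take the real part, undo the normalization.
recovered = np.real(np.fft.ifftn(cfield, axes=(-3, -2, -1))) * norm

assert np.allclose(recovered, rfield)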
Example #4
    def test_conv1d_update_state(self):
        batch = 2
        d_model = 6
        filter_size = 3
        batch_dim = mtf.Dimension("batch", batch)
        filter_dim = mtf.Dimension("filter", filter_size)

        x = np.random.randn(batch, d_model)
        x_mtf = self.converter.convert_np_array_to_mtf_tensor(
            x, dtype=tf.float32, dim_names=["batch", "d_model"])

        old_state = np.random.randn(batch, filter_size, d_model)
        old_state_mtf = self.converter.convert_np_array_to_mtf_tensor(
            old_state,
            dtype=tf.float32,
            dim_names=["batch", "filter", "d_model"])

        position_mtf = mtf.constant(self.converter.mesh,
                                    filter_size - 1,
                                    shape=mtf.Shape([batch_dim]),
                                    dtype=tf.int32)
        conv_layer = transformer_layers.Conv1D()
        output_mtf = conv_layer.update_state(old_state_mtf,
                                             x_mtf,
                                             position_mtf,
                                             filter_dim,
                                             dtype=tf.float32)
        actual = self.converter.convert_mtf_tensor_to_np_array(output_mtf)

        expected = np.empty(shape=old_state.shape)
        expected[:, :filter_size - 1, :] = old_state[:, 1:, :]
        expected[:, -1, :] = x
        self.assertAllClose(actual, expected)
Example #5
            def computation_fn():
                graph = mtf.Graph()
                mesh = mtf.Mesh(graph, 'my_mesh')
                mesh_shape = mtf.convert_to_shape('all:2')
                layout = 'none:all'
                mesh_devices = [''] * mesh_shape.size
                mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
                    mesh_shape, mtf.convert_to_layout_rules(layout),
                    mesh_devices, device_assignment)
                hidden_dim = mtf.Dimension('hidden', 3)
                w = mtf.get_variable(mesh,
                                     'w',
                                     shape=[hidden_dim],
                                     initializer=tf.constant_initializer(
                                         [0.1, -0.2, -0.1]))
                x = mtf.constant(mesh, [0.4, 0.2, -0.5], [hidden_dim],
                                 dtype=tf.float32)
                loss = mtf.reduce_mean(mtf.square(x - w))

                lr, update_ops = optimization_lib.create_optimizer(
                    loss, 0.2, 100, 10)
                self.lowering = mtf.Lowering(graph, {mesh: mesh_impl})

                tf_update_ops = [
                    self.lowering.lowered_operation(op) for op in update_ops
                ]
                tf_update_ops.append(
                    tf.assign_add(tf.train.get_or_create_global_step(), 1))
                train_op = tf.group(tf_update_ops)

                return lr, train_op
Example #6
    def sample(self, features, mesh):
        hparams = self._hparams
        model = self.model()

        def import_feature(key):
            return self._import_feature(features, mesh, key)

        if self.autoregressive:
            # Prepare partial targets.
            # In either features["inputs"] or features["targets"].
            # We force the outputs to begin with these sequences.
            partial_targets = import_feature("inputs")
            if partial_targets is None:
                partial_targets = import_feature("targets")
            if partial_targets:
                partial_targets *= mtf.cast(mtf.not_equal(partial_targets, 1),
                                            partial_targets.dtype)
            else:
                ids_shape = mtf.Shape(self.batch_dims + [self.length_dim])
                partial_targets = mtf.constant(mesh,
                                               0,
                                               ids_shape,
                                               dtype=tf.int32)
            if hparams.beam_size > 1:
                raise NotImplementedError(
                    "Beam search not implemented for unitransformer.")
            ret = model.sample_autoregressive(
                partial_targets,
                temperature=hparams.sampling_temp,
                variable_dtype=self.variable_dtype)
            return self.combine_batch_dims(ret)
        else:
            raise ValueError(
                "Don't know how to sample from non-autoregressive unitransformer"
            )
Example #7
 def computation_fn():
     graph = mtf.Graph()
     mesh = mtf.Mesh(graph, 'my_mesh')
     mesh_shape = mtf.convert_to_shape('all:2')
     layout = 'none:all'
     mesh_devices = [''] * mesh_shape.size
     mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
         mesh_shape, mtf.convert_to_layout_rules(layout),
         mesh_devices, device_assignment)
     hidden_dim = mtf.Dimension('hidden', 3)
     w = mtf.get_variable(mesh,
                          'w',
                          shape=[hidden_dim],
                          initializer=tf.constant_initializer(
                              [0.1, -0.2, -0.1]))
     x = mtf.constant(mesh, [0.4, 0.2, -0.5], [hidden_dim],
                      dtype=tf.float32)
     loss = mtf.reduce_mean(mtf.square(x - w))
     var_grads = mtf.gradients(
         [loss], [v.outputs[0] for v in graph.trainable_variables])
     optimizer = mtf_optimize.AdamWeightDecayOptimizer(
         learning_rate=0.2)
     update_ops = optimizer.apply_grads(var_grads,
                                        graph.trainable_variables)
     self.lowering = mtf.Lowering(graph, {mesh: mesh_impl})
     tf_update_ops = [
         self.lowering.lowered_operation(op) for op in update_ops
     ]
     return tf.group(tf_update_ops)
Example #8
 def test_convert_mtf_tensor_to_np_array(self):
     x_np = np.array([[1, 2, 3], [4, 5, 6]])
     converter = test_utils.NumpyConverter()
     shape = mtf.Shape([mtf.Dimension("dim0", 2), mtf.Dimension("dim1", 3)])
     x_mtf = mtf.constant(converter.mesh, x_np, shape=shape, dtype=tf.int32)
     actual = converter.convert_mtf_tensor_to_np_array(x_mtf)
     self.assertAllEqual(x_np, actual)
Example #9
  def testWhileLoopOperation(self):
    # This test case implements the following:
    # for i in range(10):
    #   x = x * 2
    i = mtf.constant(self.mesh, 0, mtf.Shape([]))
    cond_fn = lambda i, x: mtf.less(i, 10)
    body_fn = lambda i, x: [mtf.add(i, 1), mtf.multiply(x, 2)]

    while_loop_operation = mtf.WhileLoopOperation(cond_fn, body_fn, [i, self.x])
    self.assertEqual(while_loop_operation.splittable_dims,
                     frozenset(["a", "b"]))
    self.assertEqual(while_loop_operation.unsplittable_dims, frozenset())
Example #10
  def convert_np_array_to_mtf_tensor(self, x, dim_names=None, dtype=tf.int32):
    """Convert a numpy array to an equivalent mtf.Tensor."""
    dim_sizes = x.shape
    if not dim_names:
      dim_names = [f"dim{i}" for i in range(len(dim_sizes))]

    dims = []
    for dim_size, dim_name in zip(dim_sizes, dim_names):
      dims.append(mtf.Dimension(dim_name, dim_size))
    shape = mtf.Shape(dims)
    x_mtf = mtf.constant(self.mesh, x, shape=shape, dtype=dtype)
    return x_mtf
Example #11
def get_dummy_decoder_context(converter,
                              batch=2,
                              d_model=6,
                              length=4,
                              mode="incremental",
                              initial_position=None,
                              state=None,
                              inputs=None):

    batch_dim = mtf.Dimension("batch", batch)
    length_dim = mtf.Dimension("length", length)

    # Set up a dummy model
    layer_stack = transformer.LayerStack(layers=[])
    model = transformer.Unitransformer(
        d_model=d_model,
        input_vocab_size=10,  # dummy values
        output_vocab_size=10,  # dummy values
        autoregressive=True,
        max_length=length,
        layer_stack=layer_stack)

    if state is not None:
        state_mtf = converter.convert_np_array_to_mtf_tensor(
            state, dtype=tf.float32, dim_names=["batch", "length", "d_model"])
        states = [state_mtf]
    else:
        states = None

    if initial_position:
        initial_position = mtf.constant(converter.mesh,
                                        initial_position,
                                        shape=mtf.Shape([batch_dim]),
                                        dtype=tf.int32)

    if inputs is not None:
        inputs = converter.convert_np_array_to_mtf_tensor(
            inputs, dim_names=["batch", "length"])

    context = transformer.Context(model=model,
                                  mode=mode,
                                  states=states,
                                  new_states=[],
                                  mesh=converter.mesh,
                                  batch_dims=[batch_dim],
                                  length_dim=length_dim,
                                  variable_dtype=mtf.VariableDType(tf.float32),
                                  sequence_id=1,
                                  inputs=inputs,
                                  initial_position=initial_position)
    return context
Example #12
    def sample(self, features, mesh):
        hparams = self._hparams
        model = self.model()
        # Prepare partial targets.
        # In either features["inputs"] or features["targets"].
        # We force the outputs to begin with these sequences.
        partial_targets = features.get("inputs", None)
        if partial_targets is None:
            partial_targets = features.get("targets", None)
        if partial_targets is not None:
            partial_targets = common_layers.expand_squeeze_to_nd(
                partial_targets, 2)
            partial_targets = tf.to_int32(partial_targets)
            partial_targets_batch = tf.shape(partial_targets)[0]
            partial_targets_length = tf.shape(partial_targets)[1]
            partial_targets = tf.pad(
                partial_targets,
                [[0, hparams.batch_size - partial_targets_batch],
                 [0, self.length_dim.size - partial_targets_length]])
            partial_targets = self._import_to_batch_by_length(
                partial_targets, "partial_targets", mesh)
            # strip EOS
            partial_targets *= mtf.to_int32(mtf.not_equal(partial_targets, 1))

        else:
            ids_shape = mtf.Shape(self.batch_dims + [self.length_dim])
            partial_targets = mtf.constant(mesh, 0, ids_shape, dtype=tf.int32)
        if hparams.beam_size == 1:
            pass
        else:
            raise NotImplementedError("not implemented")
            # beam_dim = mtf.Dimension("beam", hparams.beam_size)
            # ids_shape = mtf.Shape(self.batch_dims + [beam_dim, self.length_dim])

        partial_targets = mtf.Print(partial_targets, [partial_targets],
                                    "Partial_Targets",
                                    summarize=1000)
        return model.sample_autoregressive(partial_targets,
                                           temperature=hparams.sampling_temp,
                                           variable_dtype=self.variable_dtype)
Example #13
 def body_fn(position, ids, *states):
     """One step in the decode loop."""
     context_incremental = Context(
         mesh=inputs.mesh,
         batch_dims=batch_dims,
         length_dim=length_dim,
         model_dim=self.model_dim,
         variable_dtype=variable_dtype,
         mode="incremental",
         autoregressive=self.autoregressive,
         position=position,
         states=states,
         new_states=[],
         sequence_id=sequence_id,
         encoder_output=encoder_output,
         encoder_sequence_id=encoder_sequence_id,
         constant_states=constant_states,
         shared_params=shared_params,
         layout=self.layout,
         mesh_shape=self.mesh_shape,
         encoder_layer_outputs=encoder_layer_outputs)
     inputs_this_step = mtf.gather(ids, position - 1, length_dim)
     with tf.variable_scope(self.name, reuse=True):
         logits = self._call_internal(context_incremental,
                                      inputs_this_step)
         if never_end:
             logits += mtf.one_hot(mtf.constant(logits.mesh,
                                                stop_at_token,
                                                dtype=tf.int32),
                                   self.output_vocab_dim,
                                   on_value=-1e9,
                                   off_value=0.0,
                                   dtype=logits.dtype)
     ids_this_step = mtf.sample_with_temperature(
         logits, self.output_vocab_dim, temperature)
     new_position = position + 1
     new_ids = ids + ids_this_step * mtf.one_hot(
         position, length_dim, dtype=tf.int32)
     return [new_position, new_ids] + context_incremental.new_states
Example #14
  def get_masked_lm_output(self, positions, label_ids, label_weights):
    """Get loss and logits for the masked LM."""
    input_tensor = self.get_sequence_output()
    output_weights = self.get_embedding_table()

    # [batch_size, num_position, hidden]
    input_tensor = mtf.gather(input_tensor, positions, self.seq_dim)
    with tf.variable_scope("cls/predictions"):
      # We apply one more non-linear transformation before the output layer.
      # This matrix is not used after pre-training.
      with tf.variable_scope("transform"):
        input_tensor = mtf.layers.dense(
            input_tensor,
            reduced_dims=[self.model_dim],
            new_dims=[self.model_dim],
            activation=get_activation(self.config.feedforward_intermediate_act),
            kernel_initializer=self.dense_initializer,
            use_bias=self.config.use_bias)
        input_tensor = self.normalize(input_tensor)
      # The output weights are the same as the input embeddings, but there is
      # an output-only bias for each token.
      output_bias = mtf.get_variable(
          input_tensor.mesh,
          name="output_bias",
          shape=[self.vocab_dim],
          initializer=tf.zeros_initializer())
      logits = mtf.einsum([input_tensor, output_weights],
                          reduced_dims=[self.model_dim]) + output_bias
      per_example_loss = mtf.layers.softmax_cross_entropy_with_logits(
          logits, label_ids, self.vocab_dim, z_loss=1e-4)
      # The `positions` tensor might be zero-padded (if the sequence is too
      # short to have the maximum number of predictions). The `label_weights`
      # tensor has a value of 1.0 for every real prediction and 0.0 for the
      # padding predictions.
      numerator = mtf.reduce_sum(label_weights * per_example_loss)
      denominator = mtf.reduce_sum(label_weights) + mtf.constant(
          input_tensor.mesh, 1e-5, dtype=tf.float32)
      loss = numerator / denominator
    return (loss, per_example_loss, logits)
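
The final reduction above is a weighted mean over the real (non-padding) predictions, with a small epsilon keeping the division safe when every weight is zero. In plain NumPy terms (the numbers are made up):

import numpy as np

per_example_loss = np.array([2.3, 1.7, 0.9, 0.0])
label_weights = np.array([1.0, 1.0, 1.0, 0.0])   # last prediction is padding
loss = np.sum(label_weights * per_example_loss) / (np.sum(label_weights) + 1e-5)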
Example #15
    def _sample(self, features, mesh):
        hparams = self._hparams
        (inputs_embedding_var, targets_embedding_var, softmax_var,
         positional_embedding_var) = self._embedding_and_softmax_vars(mesh)
        if hparams.transformer_type == "encdec":
            inputs = features["inputs"]
            while len(inputs.shape.as_list()) > 2:
                inputs = tf.squeeze(inputs, axis=2)
            actual_batch_size = tf.shape(inputs)[0]
            actual_length = tf.shape(inputs)[1]
            inputs = tf.pad(inputs,
                            [[0, hparams.batch_size - actual_batch_size],
                             [0, hparams.max_length - actual_length]])
            inputs = self._import_to_batch_by_length(inputs, "inputs", mesh,
                                                     hparams)
            x = (mtf.gather(inputs_embedding_var, inputs,
                            self.inputs_vocab_dim) +
                 mtf.reshape(positional_embedding_var,
                             mtf.Shape([self.length_dim, self.model_dim])))
            encoder_attention_mask = (mtf.layers.attention_mask_ignore_padding(
                inputs, dtype=self.activation_dtype))
            with tf.variable_scope("encoder"):
                x = self._layer_stack(
                    x,
                    hparams.encoder_layers,
                    self_attention_mask=encoder_attention_mask)
            encoder_output = mtf.rename_dimension(x, self.length_dim.name,
                                                  self.memory_length_dim.name)
            encdec_tensors = []
            for layer_num, layer_type in enumerate(hparams.decoder_layers):
                if layer_type == "enc_att":
                    with tf.variable_scope("decoder/enc_att_%d/enc_att" %
                                           layer_num):
                        q_var, k_var, v_var, o_var = mtf.layers.multihead_attention_vars(
                            mesh, self.heads_dim, self.model_dim, self.kv_dim,
                            self.master_dtype, self.slice_dtype,
                            self.activation_dtype)
                        k = mtf.einsum([encoder_output, k_var],
                                       mtf.Shape(self.batch_dims + [
                                           self.heads_dim,
                                           self.memory_length_dim, self.kv_dim
                                       ]))
                        v = mtf.einsum([encoder_output, v_var],
                                       mtf.Shape(self.batch_dims + [
                                           self.heads_dim,
                                           self.memory_length_dim, self.kv_dim
                                       ]))
                    encdec_tensors.append((q_var, o_var, k, v))
                else:
                    encdec_tensors.append(None)
            partial_targets = None
        elif hparams.transformer_type == "decoder":
            encdec_tensors = None
            encoder_output = None
            encoder_attention_mask = None
            # Prepare partial targets.
            # In either features["inputs"] or features["targets"].
            # We force the outputs to begin with these sequences.
            partial_targets = features.get("inputs", None)
            if partial_targets is None:
                partial_targets = features.get("targets", None)
            if partial_targets is not None:
                partial_targets = common_layers.expand_squeeze_to_nd(
                    partial_targets, 2)
                partial_targets = tf.to_int32(partial_targets)
                partial_targets_batch = tf.shape(partial_targets)[0]
                partial_targets_length = tf.shape(partial_targets)[1]
                partial_targets = tf.pad(
                    partial_targets,
                    [[0, hparams.batch_size - partial_targets_batch],
                     [0, hparams.max_length - partial_targets_length]])
                partial_targets = self._import_to_batch_by_length(
                    partial_targets, "partial_targets", mesh, hparams)
        else:
            raise ValueError("hparams.model_type = %s not yet supported" %
                             hparams.transformer_type)

        local_attention_window = mtf.Dimension(
            "local_attention_window", hparams.local_attention_window_size)
        if hparams.beam_size == 1:
            ids_shape = mtf.Shape(self.batch_dims + [self.length_dim])
            kv_shape = mtf.Shape(
                self.batch_dims +
                [self.heads_dim, self.memory_length_dim, self.kv_dim])
            local_kv_shape = mtf.Shape(
                self.batch_dims +
                [self.heads_dim, local_attention_window, self.kv_dim])
        else:
            beam_dim = mtf.Dimension("beam", hparams.beam_size)
            ids_shape = mtf.Shape(self.batch_dims +
                                  [beam_dim, self.length_dim])
            kv_shape = mtf.Shape(self.batch_dims + [
                beam_dim, self.heads_dim, self.memory_length_dim, self.kv_dim
            ])
            local_kv_shape = mtf.Shape(self.batch_dims + [
                beam_dim, self.heads_dim, local_attention_window, self.kv_dim
            ])

        initial_ids = mtf.constant(mesh, 0, ids_shape, dtype=tf.int32)
        initial_states = []
        for layer in hparams.decoder_layers:
            if layer == "att":
                initial_states.extend(
                    [mtf.zeros(mesh, kv_shape, dtype=self.activation_dtype)] *
                    2)
            elif layer == "local_att":
                initial_states.extend([
                    mtf.zeros(
                        mesh, local_kv_shape, dtype=self.activation_dtype)
                ] * 2)

        def logits_fn(step_num, ids, states):
            """Produce logits for this step, and new states."""
            ids_this_step = mtf.gather(ids, step_num - 1, self.length_dim)
            x = (mtf.gather(targets_embedding_var, ids_this_step,
                            self.targets_vocab_dim) +
                 mtf.gather(positional_embedding_var, step_num,
                            self.max_length_dim))
            with tf.variable_scope("decoder"):
                x, new_states = self._layer_stack(
                    x,
                    hparams.decoder_layers,
                    encdec_attention_mask=encoder_attention_mask,
                    step_num=step_num,
                    encdec_tensors=encdec_tensors,
                    states=states)
            logits = mtf.matmul(x, softmax_var)
            return logits, new_states

        if hparams.beam_size == 1:
            temperature = (0.0 if hparams.sampling_method == "argmax" else
                           hparams.sampling_temp)
            return mtf.beam_search.greedy_decode(logits_fn,
                                                 initial_ids,
                                                 temperature=temperature,
                                                 initial_states=initial_states,
                                                 forced_ids=partial_targets,
                                                 use_tpu=hparams.use_tpu)
        else:
            if hparams.transformer_type == "encdec":
                input_length = mtf.reduce_sum(mtf.to_float(
                    mtf.cast(inputs, tf.bool)),
                                              reduced_dim=self.length_dim)
                max_input_length = mtf.reduce_max(input_length)
                decode_length = mtf.cast(
                    max_input_length * hparams.decode_length_multiplier +
                    hparams.decode_length_constant, tf.int32)
            else:
                decode_length = None
            beams, unused_scores = mtf.beam_search.beam_search(
                logits_fn,
                initial_ids,
                hparams.alpha,
                states=initial_states,
                decode_length=decode_length,
                use_tpu=hparams.use_tpu,
                dtype=self.activation_dtype)
            return mtf.gather(beams, mtf.constant(mesh, 0, dtype=tf.int32),
                              beam_dim)
Example #16
    def decode(self,
               inputs,
               variable_dtype=mtf.VariableDType(tf.float32),
               beam_size=1,
               alpha=0.6,
               temperature=1.0,
               decode_length_multiplier=1.5,
               decode_length_constant=10):
        """Sampling or beam search.

    TODO(noam): should we make the output length dimension different from the
    input length dimension?

    Args:
      inputs: a Tensor with shape [<batch_dims>, beam_dim, length_dim]
      variable_dtype: a mtf.VariableDType
      beam_size: an integer >= 1
      alpha: a floating point value (length bonus for beam search)
      temperature: a value between 0 and 1 (must be 0 if beam_size > 1)
      decode_length_multiplier: a float
      decode_length_constant: a float

    Returns:
      a Tensor with shape [<batch_dims>, beam_dim, length_dim]
    """
        shared_params = self._shared_params(inputs.mesh, variable_dtype)
        encoder_sequence_id = mtf.minimum(inputs, 1)
        encoder_output, encoder_loss = self.encoder.call_simple(
            inputs=inputs,
            targets=None,
            compute_loss=False,
            mode=tf.estimator.ModeKeys.PREDICT,
            variable_dtype=variable_dtype,
            sequence_id=encoder_sequence_id,
            shared_params=shared_params)
        del encoder_loss
        encoder_output = mtf.layers.rename_length_to_memory_length(
            encoder_output)
        encoder_sequence_id = mtf.layers.rename_length_to_memory_length(
            encoder_sequence_id)
        if beam_size == 1:
            ids_shape = inputs.shape
            partial_targets = mtf.constant(inputs.mesh,
                                           0,
                                           ids_shape,
                                           dtype=tf.int32)
            return self.decoder.sample_autoregressive(
                partial_targets,
                temperature=temperature,
                variable_dtype=variable_dtype,
                encoder_output=encoder_output,
                encoder_sequence_id=encoder_sequence_id,
                shared_params=shared_params)
        else:
            if temperature != 0:
                raise ValueError(
                    "don't know how to beam search with nonzero temperature")
            # beam search
            beam_dim = mtf.Dimension("beam", beam_size)
            batch_dims = inputs.shape[:-1]
            length_dim = inputs.shape[-1]
            ids_shape = mtf.Shape(batch_dims + [beam_dim, length_dim])
            partial_targets = mtf.constant(inputs.mesh,
                                           0,
                                           ids_shape,
                                           dtype=tf.int32)
            input_length = mtf.reduce_sum(mtf.to_float(
                mtf.cast(inputs, tf.bool)),
                                          reduced_dim=length_dim)
            max_input_length = mtf.reduce_max(input_length)
            decode_length = mtf.cast(
                max_input_length * decode_length_multiplier +
                decode_length_constant, tf.int32)
            return self.decoder.beam_search(
                partial_targets,
                decode_length,
                variable_dtype=variable_dtype,
                encoder_output=encoder_output,
                encoder_sequence_id=encoder_sequence_id,
                alpha=alpha,
                shared_params=shared_params)
Example #17
def get_optimizer(mesh, loss, params, variable_dtype, inp_var_grads=None):
    """Creates and returns an optimizer training op."""
    global_step = tf.train.get_or_create_global_step()

    # set defaults
    end_step = params.get("lr_decay_end", params["train_steps"])
    lr_decay = params.get("lr_decay", "cosine")
    warmup_steps = params.get("warmup_steps", 3000)
    gradient_clipping = params.get("gradient_clipping", 1.0)
    optimizer_name = params.get("optimizer", "adam")

    learning_rate = tf.constant(value=params["lr"],
                                shape=[],
                                dtype=variable_dtype.slice_dtype)
    clip_value = mtf.constant(mesh,
                              gradient_clipping,
                              dtype=variable_dtype.slice_dtype)

    if inp_var_grads is None:
        var_grads = mtf.gradients(
            [loss], [v.outputs[0] for v in mesh.graph.trainable_variables])
    else:
        var_grads = inp_var_grads

    valid_grads_vars = list(
        filter(lambda grad_var: grad_var[0] is not None,
               zip(var_grads, mesh.graph.trainable_variables)))
    valid_vars = [var for grad, var in valid_grads_vars]
    valid_grad = [grad for grad, var in valid_grads_vars]

    tf.logging.info([
        v for v in zip(var_grads,
                       [v.outputs[0] for v in mesh.graph.trainable_variables])
    ])
    # Cast to full precision
    var_grads_fp = [
        mtf.cast(v, variable_dtype.slice_dtype) for v in valid_grad
    ]

    if lr_decay == "linear":
        learning_rate = tf.train.polynomial_decay(
            learning_rate,
            global_step,
            end_step,
            end_learning_rate=params["lr"] *
            0.1,  # Decrease to 10% of initial LR according to GPT-3 paper
            power=1.0,
            cycle=False)
    elif lr_decay == "cosine":
        learning_rate = tf.train.cosine_decay(
            learning_rate,
            global_step,
            end_step,
            alpha=0.1  # Alpha is min lr value as a fraction of init lr.
        )

    if warmup_steps > 0:
        global_steps_int = tf.cast(global_step, tf.int32)
        warmup_steps_int = tf.constant(warmup_steps, dtype=tf.int32)

        dtype = variable_dtype.slice_dtype

        global_steps_float = tf.cast(global_steps_int, dtype)
        warmup_steps_float = tf.cast(warmup_steps_int, dtype)

        warmup_percent_done = global_steps_float / warmup_steps_float
        warmup_learning_rate = learning_rate * warmup_percent_done

        is_warmup = tf.cast(global_steps_int < warmup_steps_int, dtype)
        learning_rate = ((1.0 - is_warmup) * learning_rate +
                         is_warmup * warmup_learning_rate)

    learning_rate = mtf.import_fully_replicated(mesh,
                                                learning_rate,
                                                mtf.Shape([]),
                                                name="learning_rate")
    scalar_summary("lr", learning_rate)

    if optimizer_name.lower() == "adam":
        optimizer = mtf.optimize.AdamWeightDecayOptimizer(
            learning_rate=learning_rate,
            weight_decay_rate=params.get("weight_decay", 0.0),
            beta_1=params.get("beta_1", 0.9),
            beta_2=params.get("beta_2", 0.999),
            epsilon=params.get("epsilon", 1e-6),
            exclude_from_weight_decay=["norm", "bias"])
    elif optimizer_name.lower() == "adafactor":
        optimizer = mtf.optimize.AdafactorOptimizer(
            learning_rate=learning_rate,
            decay_rate=params.get("weight_decay", 0.0),
            beta1=params.get("beta_1", 0.9),
            epsilon1=params.get("epsilon_1", 1e-30),
            epsilon2=params.get("epsilon_2", 1e-3))
    else:
        raise ValueError(f"{optimizer_name} not recognized")

    if gradient_clipping is not None:
        (var_grads_fp, _) = clip_by_global_norm(var_grads_fp,
                                                clip_norm=clip_value)
    update_ops = optimizer.apply_grads(var_grads_fp, valid_vars)
    return learning_rate, update_ops, var_grads_fp
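
The warmup block above blends the already-decayed rate with a linearly ramped copy of itself. Restated as a small plain-Python helper (a sketch of the arithmetic only, not part of the original code):

def effective_lr(decayed_lr, step, warmup_steps):
    """While step < warmup_steps, scale the decayed rate linearly from zero
    to its full value; afterwards return the decayed rate unchanged."""
    if warmup_steps > 0 and step < warmup_steps:
        return decayed_lr * (step / warmup_steps)
    return decayed_lr

For instance, effective_lr(1e-4, 1500, 3000) gives 5e-5, i.e. half of the decayed rate halfway through warmup.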
Example #18
 def scalar(v, dtype):
     return mtf.constant(mesh, v, shape=[], dtype=dtype)
Example #19
    def test_get_indices(self):
        key_size = 2
        n_keys = 3
        product_size = 2
        head_size = 2
        batch = 2
        seq_len = 2
        knn = 2

        n_key_dim = mtf.Dimension("n_keys", n_keys)
        key_dim = mtf.Dimension("key", key_size // 2)
        seq_dim = mtf.Dimension("length", seq_len)
        batch_dim = mtf.Dimension("batch", batch)
        head_dim = mtf.Dimension("n_heads", head_size)
        product_dim = mtf.Dimension("product_key", product_size)
        knn_dim = mtf.Dimension("knn", knn)

        query_shape = mtf.Shape(
            [batch_dim, seq_dim, head_dim, product_dim, key_dim])
        keys_shape = mtf.Shape([head_dim, product_dim, n_key_dim, key_dim])
        query = mtf.ones(self.mesh, query_shape)

        keys_vals = [
            [
                [[4], [1], [2]],
                [[2], [-1], [2]],
            ],
            [
                [[1], [2], [5]],
                [[6], [1], [4]],
            ],
        ]
        # h1:
        #   First scores:
        #   [4, 2]
        #   [2, 2]
        #   Cartesian added scores:
        #   [6, 6]
        #   Indices:
        #   [0, 2]    [0*n_k + 0, 0*n_k + 2]
        # h2:
        #   First scores:
        #   [5, 2]
        #   [6, 4]
        #   Cartesian added scores:
        #   [11, 9]
        #   Indices:
        #   [6, 8]   [2*n_k+0, 2*n_k+2]
        expected_scores = np.broadcast_to(np.array([[6, 6], [11, 9]]),
                                          [batch, seq_len, head_size, knn])
        expected_indices = np.broadcast_to(np.array([[0, 2], [6, 8]]),
                                           [batch, seq_len, head_size, knn])

        keys = mtf.constant(self.mesh, keys_vals, keys_shape)

        pkm = memory_layers.ProductKeyValueMemory(key_size, n_keys, head_size,
                                                  knn)
        mtf_scores, mtf_indices = pkm.get_indices(keys, query)

        # Shapes.
        expected_shape = mtf.Shape([batch_dim, seq_dim, head_dim, knn_dim])
        self.assertEqual(expected_shape, mtf_scores.shape)
        self.assertEqual(expected_shape, mtf_indices.shape)

        # Values
        lowering_s, scores = self._export_to_tf_tensor(mtf_scores)
        lowering_i, indices = self._export_to_tf_tensor(mtf_indices)
        self.evaluate(tf.global_variables_initializer())
        self.evaluate(lowering_s.copy_masters_to_slices())
        self.evaluate(lowering_i.copy_masters_to_slices())
        scores, indices = self.evaluate([scores, indices])

        self.assertAllEqual(expected_scores, scores)
        self.assertAllEqual(expected_indices, indices)
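
The scoring arithmetic described in the comment above can be reproduced for one head with a few lines of NumPy: take the top-knn scores from each key half, form all Cartesian sums, and keep the best knn together with their composite indices (i1 * n_keys + i2). A sketch using head h1's values:

import numpy as np

n_keys, knn = 3, 2
half1 = np.array([4.0, 1.0, 2.0])      # scores against the first key half
half2 = np.array([2.0, -1.0, 2.0])     # scores against the second key half

top1 = np.argsort(-half1)[:knn]        # best knn indices in each half
top2 = np.argsort(-half2)[:knn]

sums = half1[top1][:, None] + half2[top2][None, :]            # (knn, knn) Cartesian sums
flat_idx = (top1[:, None] * n_keys + top2[None, :]).ravel()   # composite key indices
order = np.argsort(-sums.ravel())[:knn]

print(sums.ravel()[order])   # [6. 6.]
print(flat_idx[order])       # [0 2]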
Example #20
        def body_fn(position, ids, *states):
            """One step in the decode loop."""
            inputs_this_step = mtf.gather(ids, position - 1, length_dim)
            if self.attribute_embedding:
                attributes_this_step = mtf.gather(attributes, position - 1,
                                                  length_dim)
            else:
                attributes_this_step = None
            # raise ValueError("inputs_this_step shape=%s , ids shape=%s, position - 1 shape=%s, length_dim=%s" % (inputs_this_step.shape, ids.shape, (position - 1).shape, length_dim))
            context_incremental = Context(
                model=self,
                mesh=inputs.mesh,
                batch_dims=batch_dims,
                length_dim=length_dim,
                variable_dtype=variable_dtype,
                mode="incremental",
                position=position,
                states=states,
                new_states=[],
                sequence_id=sequence_id,
                encoder_output=encoder_output,
                encoder_sequence_id=encoder_sequence_id,
                constant_states=constant_states,
                shared_params=shared_params,
                encoder_layer_outputs=encoder_layer_outputs,
                write_priority=write_priority,
                read_priority=position,
                inputs=inputs_this_step,
                encoder_inputs=encoder_inputs)

            with tf.variable_scope(self.name, reuse=True):
                logits = self._call_internal(context_incremental,
                                             inputs_this_step,
                                             attributes=attributes_this_step,
                                             z=z)
                if never_end:
                    logits += mtf.one_hot(mtf.constant(logits.mesh,
                                                       stop_at_token,
                                                       dtype=tf.int32),
                                          self.output_vocab_dim,
                                          on_value=-1e9,
                                          off_value=0.0,
                                          dtype=logits.dtype)

            # TBD whether this should be before or after never_end:
            # Note for adding top_p sampling in the future, in other code bases, the
            # option to apply temperature is done before the top-k truncation. This
            # implementation does this in the opposite order. For top-k this doesn't
            # matter, but for top_p it will.
            if sampling_keep_top_k != -1:
                if sampling_keep_top_k <= 0:
                    raise ValueError(
                        "sampling_keep_top_k must either be -1 or positive.")
                k_largest = mtf.nth_largest_element(
                    logits,
                    n=sampling_keep_top_k,
                    reduced_dim=self.output_vocab_dim)
                logits = mtf.where(mtf.less_equal(logits, k_largest),
                                   mtf.ones_like(logits) * -1e6, logits)

            ids_this_step = mtf.sample_with_temperature(
                logits, self.output_vocab_dim, temperature)
            new_position = position + 1
            new_ids = ids + ids_this_step * mtf.one_hot(
                position, length_dim, dtype=tf.int32)
            return [new_position, new_ids] + context_incremental.new_states
Example #21
    def beam_search(self,
                    inputs,
                    decode_length,
                    dst_attributes=None,
                    variable_dtype=mtf.VariableDType(tf.float32),
                    encoder_output=None,
                    encoder_sequence_id=None,
                    encoder_inputs=None,
                    alpha=0.6,
                    shared_params=None,
                    encoder_layer_outputs=None,
                    z=None):
        """Beam search.
        Args:
          inputs: an int32 zero-Tensor with shape [<batch_dims>, beam_dim,
            length_dim].
          decode_length: an int32 mtf scalar.  Maximum decode length.
          dst_attributes: an int32 zero-Tensor with shape
            [<batch_dims>, beam_dim, length_dim], [<batch_dims>] or
            [<batch_dims>, beam_dim].
          variable_dtype: a mtf.VariableDType
          encoder_output: an optional Tensor
          encoder_sequence_id: an optional Tensor
          encoder_inputs: an optional Tensor
          alpha: a floating point value (length bonus)
          shared_params: an optional dictionary
          encoder_layer_outputs: optional - readonly list of tensor activations when
            decoding, one per each input layer + the embedding layer
        Returns:
          a Tensor with shape [<batch_dims>, beam_dim, length_dim]
        """
        attributes = dst_attributes
        if not self.autoregressive:
            raise ValueError("must be autoregressive")

        batch_dims = inputs.shape.dims[:-2]
        if len(batch_dims) != 1:
            raise NotImplementedError(
                "beam search supports exactly one batch dimension.")
        beam_dim = inputs.shape.dims[-2]
        length_dim = inputs.shape.dims[-1]
        length_range = mtf.range(inputs.mesh, length_dim, tf.int32)
        initial_position = mtf.reduce_sum(mtf.to_int32(mtf.not_equal(
            inputs, 0)),
                                          reduced_dim=length_dim)
        sequence_id = 1 if encoder_sequence_id is not None else None

        if self.input_full_attention:
            # This only makes sense in the case of beam search with given partial
            # sequences, which is not yet implemented.
            # TODO(noam): implement
            raise NotImplementedError(
                "Beam search for language models not yet implemented")
        else:
            read_priority = write_priority = length_range

        context_first_part = Context(
            model=self,
            mesh=inputs.mesh,
            batch_dims=batch_dims + [beam_dim],
            length_dim=length_dim,
            variable_dtype=variable_dtype,
            mode="first_part",
            position=length_range,
            position_is_default=True,
            new_states=[],
            initial_position=initial_position,
            sequence_id=sequence_id,
            encoder_output=encoder_output,
            encoder_sequence_id=encoder_sequence_id,
            constant_states=[],
            shared_params=shared_params,
            encoder_layer_outputs=encoder_layer_outputs,
            write_priority=write_priority,
            read_priority=read_priority,
            inputs=inputs,
            encoder_inputs=encoder_inputs)

        shifted_inputs = mtf.shift(inputs,
                                   offset=1,
                                   dim=length_dim,
                                   wrap=False)
        with tf.variable_scope(self.name):
            logits = self._call_internal(context_first_part,
                                         shifted_inputs,
                                         attributes=attributes,
                                         z=z)
        del logits
        # There are no partial targets.
        # Replace initial states by zeros to avoid computing them.
        initial_states = [
            mtf.zeros_like(t) for t in context_first_part.new_states
        ]
        constant_states = context_first_part.constant_states

        def logits_fn(step_num, ids, states):
            """logits_fn for mtf.beam_search.beam_search()."""
            inputs_this_step = mtf.gather(ids, step_num - 1, length_dim)

            if self.attribute_embedding:
                attributes_this_step = mtf.gather(attributes, step_num - 1,
                                                  length_dim)
            else:
                attributes_this_step = None

            context_incremental = Context(
                model=self,
                mesh=inputs.mesh,
                batch_dims=batch_dims + [beam_dim],
                length_dim=length_dim,
                variable_dtype=variable_dtype,
                mode="incremental",
                position=step_num,
                states=states,
                new_states=[],
                sequence_id=sequence_id,
                encoder_output=encoder_output,
                encoder_sequence_id=encoder_sequence_id,
                constant_states=constant_states,
                shared_params=shared_params,
                encoder_layer_outputs=encoder_layer_outputs,
                write_priority=write_priority,
                read_priority=step_num,
                inputs=inputs_this_step,
                encoder_inputs=encoder_inputs)
            with tf.variable_scope(self.name, reuse=True):
                logits = self._call_internal(context_incremental,
                                             inputs_this_step,
                                             attributes=attributes_this_step,
                                             z=z)
            return mtf.to_float(logits), context_incremental.new_states

        beams, unused_scores = mtf.beam_search.beam_search(
            logits_fn,
            inputs,
            alpha,
            states=initial_states,
            decode_length=decode_length,
            use_tpu=True,
            dtype=tf.float32,
            mesh_shape=self.mesh_shape,
            layout=self.layout)
        return mtf.gather(beams, mtf.constant(inputs.mesh, 0, dtype=tf.int32),
                          beam_dim)
Example #22
    def act_layer(self, context, x, mask):
        """Build a Universal Transformer ACT layer."""
        state = x
        act_max_steps = self.act_max_steps
        threshold = 1.0 - self.act_epsilon
        state_shape_static = state.shape.dims

        state_slice = slice(0, 3)
        if self.act_type == "global":
            state_slice = slice(0, 2)

        # Dynamic shape for update tensors below
        update_shape = state_shape_static[state_slice]

        # Halting probabilities (p_t^n in the paper)
        halting_probability = mtf.zeros(context.mesh,
                                        update_shape,
                                        dtype=context.activation_dtype)

        # Remainders (R(t) in the paper)
        remainders = mtf.zeros(context.mesh,
                               update_shape,
                               dtype=context.activation_dtype)

        # Number of updates performed (N(t) in the paper)
        n_updates = mtf.zeros(context.mesh,
                              update_shape,
                              dtype=context.activation_dtype)

        # Previous cell states (s_t in the paper)
        previous_state = mtf.zeros_like(state)
        step = mtf.constant(context.mesh, 0, dtype=tf.int32)

        def ut_function(state, step, halting_probability, remainders,
                        n_updates, previous_state):
            """implements act (position-wise halting).

      Args:
        state: 3-D Tensor: [batch_size, length, channel]
        step: indicates number of steps taken so far
        halting_probability: halting probability
        remainders: act remainders
        n_updates: act n_updates
        previous_state: previous state

      Returns:
        transformed_state: transformed state
        step: step+1
        halting_probability: halting probability
        remainders: act remainders
        n_updates: act n_updates
        new_state: new state
      """
            state = self.step_preprocess(context, state, step)

            if self.act_type == "random":
                # random as halting probability
                p = mtf.random_uniform(context.mesh,
                                       shape=halting_probability.shape.dims,
                                       dtype=context.variable_dtype)
            else:
                last_dim_name = state.shape.dimension_names[-1]
                new_dims = [mtf.Dimension(last_dim_name, 1)]
                with tf.variable_scope("sigmoid_activation_for_pondering",
                                       reuse=tf.AUTO_REUSE):
                    p = mtf.layers.dense(state,
                                         variable_dtype=context.variable_dtype,
                                         reduced_dims=[state.shape.dims[-1]],
                                         new_dims=new_dims,
                                         activation=mtf.sigmoid,
                                         use_bias=True)
                    if self.act_type == "global":
                        # average over all positions (as a global halting prob)
                        p = mtf.reduce_mean(p, reduced_dim=p.shape.dims[1])
                        p = mtf.squeeze(p)
                    else:
                        # maintain position-wise probabilities
                        new_shape = p.shape.dims[:-1]
                        p = mtf.reshape(p, new_shape)
            # Mask for inputs which have not halted yet
            still_running = mtf.cast(mtf.less(halting_probability, 1.0),
                                     context.activation_dtype)

            # Mask of inputs which halted at this step
            new_halted = mtf.cast(
                mtf.greater(halting_probability + p * still_running,
                            threshold),
                context.activation_dtype) * still_running
            # Mask of inputs which haven't halted, and didn't halt this step
            still_running = mtf.cast(
                mtf.less_equal(halting_probability + p * still_running,
                               threshold),
                context.activation_dtype) * still_running

            # Add the halting probability for this step to the halting
            # probabilities for those input which haven't halted yet
            halting_probability += p * still_running

            # Compute remainders for the inputs which halted at this step
            remainders += new_halted * (1 - halting_probability)

            # Add the remainders to those inputs which halted at this step
            halting_probability += new_halted * remainders

            # Increment n_updates for all inputs which are still running
            n_updates += still_running + new_halted

            # Compute the weight to be applied to the new state and output
            # 0 when the input has already halted
            # p when the input hasn't halted yet
            # the remainders when it halted this step
            input_tensor = p * still_running + new_halted * remainders
            update_weights = input_tensor

            # apply transformation on the state
            transformed_state = state

            for _ in range(self.num_inrecurrence_layers):
                transformed_state = self.vanilla_transformer_layer(
                    context, transformed_state, mask)

            # update running part in the weighted state and keep the rest
            new_state = ((transformed_state * update_weights) +
                         (previous_state * (1 - update_weights)))

            if self.act_type == "accumulated":
                # Add in the weighted state
                new_state = (transformed_state *
                             update_weights) + previous_state

            step += 1

            return (transformed_state, step, halting_probability, remainders,
                    n_updates, new_state)

        for _ in range(act_max_steps + 1):
            (state, step, halting_probability, remainders, n_updates,
             previous_state) = ut_function(state, step, halting_probability,
                                           remainders, n_updates,
                                           previous_state)
        ponder_times = n_updates

        mtf.scalar_summary("ponder_times", mtf.reduce_mean(ponder_times))
        return previous_state
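
The halting bookkeeping inside ut_function is easiest to follow on a toy case. The sketch below runs a single update for three positions with made-up probabilities (threshold and values are illustrative only):

import numpy as np

threshold = 0.99
halting_probability = np.array([0.0, 0.95, 1.0])   # one value per position
remainders = np.zeros(3)
n_updates = np.zeros(3)
p = np.array([0.3, 0.1, 0.2])                      # halting prob. emitted this step

still_running = (halting_probability < 1.0).astype(float)
new_halted = ((halting_probability + p * still_running) > threshold).astype(float) * still_running
still_running = ((halting_probability + p * still_running) <= threshold).astype(float) * still_running

halting_probability += p * still_running           # accumulate for running positions
remainders += new_halted * (1 - halting_probability)
halting_probability += new_halted * remainders     # push freshly halted positions to 1.0
n_updates += still_running + new_halted

# Weight given to the freshly transformed state at each position:
update_weights = p * still_running + new_halted * remainders
print(update_weights)                              # approx. [0.3, 0.05, 0.0]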
Example #23
    def get_timing_signal_1d(self,
                             context,
                             length,
                             channels,
                             min_timescale=1.0,
                             max_timescale=1.0e4,
                             start_index=0):
        """Gets a bunch of sinusoids of different frequencies.

    Each channel of the input Tensor is incremented by a sinusoid of a different
    frequency and phase.

    This allows attention to learn to use absolute and relative positions.
    Timing signals should be added to some precursors of both the query and the
    memory inputs to attention.

    The use of relative position is possible because sin(x+y) and cos(x+y) can
    be expressed in terms of y, sin(x) and cos(x).

    In particular, we use a geometric sequence of timescales starting with
    min_timescale and ending with max_timescale.  The number of different
    timescales is equal to channels / 2. For each timescale, we
    generate the two sinusoidal signals sin(timestep/timescale) and
    cos(timestep/timescale).  All of these sinusoids are concatenated in
    the channels dimension.

    Args:
      context: mtf context.
      length: a mtf.Dimension, length of timing signal sequence.
      channels: a mtf.Dimension, size of timing embeddings to create.
      The number of different timescales is equal to channels / 2.
      min_timescale: a float
      max_timescale: a float
      start_index: index of first position

    Returns:
      a Tensor of timing signals [1, length, channels]
    """

        position = context.get_position() + start_index
        num_timescales = mtf.constant(context.mesh, channels.size // 2)
        log_timescale_increment = (
            math.log(float(max_timescale) / float(min_timescale)) /
            mtf.maximum(num_timescales - 1, 1))
        channel_dim_name = channels.name
        inv_timescales = (min_timescale * mtf.exp(
            mtf.mtf_range(context.mesh,
                          mtf.Dimension(channel_dim_name, channels.size // 2),
                          context.activation_dtype) * -log_timescale_increment)
                          )

        scaled_time = position * inv_timescales
        # Please note that this slightly differs from the published paper.
        # See a discussion here:
        # https://github.com/tensorflow/tensor2tensor/pull/177
        #    concat_dim_name = scaled_time.shape.dimension_names[1]
        concat_dim_name = channels.name
        signal = mtf.concat(
            [mtf.sin(scaled_time), mtf.cos(scaled_time)],
            concat_dim_name=concat_dim_name)

        if channels.size % 2 != 0:
            raise NotImplementedError("Odd channel size not implemented.")
        new_dims = [mtf.Dimension("expanded", 1)
                    ] + length.shape.dims + channels.shape.dim
        signal = mtf.reshape(signal, mtf.Shape(new_dims))
        return signal
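
A standalone NumPy version of the same sinusoid construction (the sizes are made-up; channels must be even, as the code above enforces):

import numpy as np

length, channels = 16, 8
min_timescale, max_timescale = 1.0, 1.0e4

position = np.arange(length, dtype=np.float64)
num_timescales = channels // 2
log_timescale_increment = (np.log(max_timescale / min_timescale) /
                           max(num_timescales - 1, 1))
inv_timescales = min_timescale * np.exp(
    np.arange(num_timescales) * -log_timescale_increment)

scaled_time = position[:, None] * inv_timescales[None, :]    # (length, channels/2)
signal = np.concatenate([np.sin(scaled_time), np.cos(scaled_time)], axis=1)
signal = signal[None, :, :]                                  # [1, length, channels]
print(signal.shape)                                          # (1, 16, 8)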
Example #24
def create_optimizer(loss,
                     init_lr,
                     num_train_steps,
                     num_warmup_steps,
                     max_optimized_variable_size=None,
                     optimizer="adam",
                     clip_gradients=True):
    """Creates an optimizer training op."""
    global_step = tf.train.get_or_create_global_step()
    mesh = loss.mesh

    if init_lr:
        # Implements linear decay of the learning rate.
        learning_rate = tf.constant(value=init_lr, shape=[], dtype=tf.float32)
        learning_rate = tf.train.polynomial_decay(learning_rate,
                                                  global_step,
                                                  num_train_steps,
                                                  end_learning_rate=0.0,
                                                  power=1.0,
                                                  cycle=False)
        # Implements linear warmup. I.e., if global_step < num_warmup_steps, the
        # learning rate will be `global_step/num_warmup_steps * init_lr`.
        if num_warmup_steps:
            global_steps_int = tf.cast(global_step, tf.int32)
            warmup_steps_int = tf.constant(num_warmup_steps, dtype=tf.int32)

            global_steps_float = tf.cast(global_steps_int, tf.float32)
            warmup_steps_float = tf.cast(warmup_steps_int, tf.float32)

            warmup_percent_done = global_steps_float / warmup_steps_float
            warmup_learning_rate = init_lr * warmup_percent_done

            is_warmup = tf.cast(global_steps_int < warmup_steps_int,
                                tf.float32)
            learning_rate = ((1.0 - is_warmup) * learning_rate +
                             is_warmup * warmup_learning_rate)

        mtf_learning_rate = mtf.import_tf_tensor(mesh, learning_rate, [])
    else:
        if optimizer == "adam":
            raise ValueError("Adam does not have a default learning rate")
        learning_rate = None
        mtf_learning_rate = None

    # It is recommended that you use this optimizer for fine tuning, since this
    # is how the model was trained (note that the Adam m/v variables are NOT
    # loaded from init_checkpoint.)
    if optimizer == "adam":
        optimizer = mtf_optimize.AdamWeightDecayOptimizer(
            learning_rate=mtf_learning_rate,
            weight_decay_rate=0.01,
            beta_1=0.9,
            beta_2=0.999,
            epsilon=1e-6,
            exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"])
    elif optimizer == "adafactor":
        optimizer = mtf_optimize.AdafactorOptimizer(
            learning_rate=learning_rate, min_dim_size_to_factor=32)
    else:
        raise ValueError("unknown optimizer")

    trainable_variables = mesh.graph.trainable_variables
    if max_optimized_variable_size:
        trainable_variables = [
            t for t in trainable_variables
            if t.shape.size <= max_optimized_variable_size
        ]

    var_grads = mtf.gradients([loss],
                              [v.outputs[0] for v in trainable_variables])

    # This is how the model was pre-trained.
    if clip_gradients:
        (var_grads,
         _) = clip_by_global_norm(var_grads,
                                  clip_norm=mtf.constant(mesh,
                                                         1.0,
                                                         dtype=tf.float32))

    update_ops = optimizer.apply_grads(var_grads, trainable_variables)

    return learning_rate, update_ops
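The comments above describe the schedule in words; as a sanity check, here is a small pure-Python sketch of the same linear-warmup-then-linear-decay rule (the helper name is made up for illustration).

def lr_at_step(step, init_lr, num_train_steps, num_warmup_steps):
    # Linear warmup: global_step / num_warmup_steps * init_lr.
    if num_warmup_steps and step < num_warmup_steps:
        return init_lr * step / num_warmup_steps
    # Polynomial decay with power=1.0 and end_learning_rate=0.0, i.e. a
    # straight line from init_lr down to 0 over num_train_steps.
    step = min(step, num_train_steps)
    return init_lr * (1.0 - step / num_train_steps)

For example, with init_lr=1e-4, num_train_steps=10000 and num_warmup_steps=1000, the rate ramps from 0 towards 1e-4 over the first 1000 steps and then decays linearly towards 0.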
Exemplo n.º 25
0
  def _sample(self, features, mesh):
    hparams = self._hparams
    (inputs_embedding_var,
     targets_embedding_var,
     softmax_var,
     positional_embedding_var) = self._embedding_and_softmax_vars(mesh)
    if hparams.transformer_type == "encdec":
      inputs = features["inputs"]
      while len(inputs.shape.as_list()) > 2:
        inputs = tf.squeeze(inputs, axis=2)
      actual_batch_size = tf.shape(inputs)[0]
      actual_length = tf.shape(inputs)[1]
      inputs = tf.pad(
          inputs, [[0, hparams.batch_size - actual_batch_size],
                   [0, hparams.max_length - actual_length]])
      inputs = self._import_to_batch_by_length(
          inputs, "inputs", mesh, hparams)
      x = (mtf.gather(inputs_embedding_var, inputs, self.inputs_vocab_dim) +
           mtf.reshape(positional_embedding_var,
                       mtf.Shape([self.length_dim, self.model_dim])))
      encoder_attention_mask = (
          mtf.layers.attention_mask_ignore_padding(
              inputs, dtype=self.activation_dtype))
      with tf.variable_scope("encoder"):
        x = self._layer_stack(x,
                              hparams.encoder_layers,
                              self_attention_mask=encoder_attention_mask)
      encoder_output = mtf.rename_dimension(
          x, self.length_dim.name, self.memory_length_dim.name)
      encdec_tensors = []
      for layer_num, layer_type in enumerate(hparams.decoder_layers):
        if layer_type == "enc_att":
          with tf.variable_scope("decoder/enc_att_%d/enc_att" % layer_num):
            q_var, k_var, v_var, o_var = mtf.layers.multihead_attention_vars(
                mesh, self.heads_dim, self.model_dim,
                self.kv_dim, self.master_dtype, self.slice_dtype,
                self.activation_dtype)
            k = mtf.einsum(
                [encoder_output, k_var],
                mtf.Shape(
                    self.batch_dims + [self.heads_dim,
                                       self.memory_length_dim, self.kv_dim]))
            v = mtf.einsum(
                [encoder_output, v_var],
                mtf.Shape(
                    self.batch_dims + [self.heads_dim,
                                       self.memory_length_dim, self.kv_dim]))
          encdec_tensors.append((q_var, o_var, k, v))
        else:
          encdec_tensors.append(None)
      partial_targets = None
    elif hparams.transformer_type == "decoder":
      encdec_tensors = None
      encoder_output = None
      encoder_attention_mask = None
      # Prepare partial targets.
      # In either features["inputs"] or features["targets"].
      # We force the outputs to begin with these sequences.
      partial_targets = features.get("inputs", None)
      if partial_targets is None:
        partial_targets = features.get("targets", None)
      if partial_targets is not None:
        partial_targets = common_layers.expand_squeeze_to_nd(partial_targets, 2)
        partial_targets = tf.to_int32(partial_targets)
        partial_targets_batch = tf.shape(partial_targets)[0]
        partial_targets_length = tf.shape(partial_targets)[1]
        partial_targets = tf.pad(
            partial_targets, [[0, hparams.batch_size - partial_targets_batch],
                              [0, hparams.max_length - partial_targets_length]])
        partial_targets = self._import_to_batch_by_length(
            partial_targets, "partial_targets", mesh, hparams)
    else:
      raise ValueError(
          "hparams.transformer_type = %s not yet supported"
          % hparams.transformer_type)

    local_attention_window = mtf.Dimension(
        "local_attention_window", hparams.local_attention_window_size)
    if hparams.beam_size == 1:
      ids_shape = mtf.Shape(self.batch_dims + [self.length_dim])
      kv_shape = mtf.Shape(self.batch_dims +
                           [self.heads_dim,
                            self.memory_length_dim, self.kv_dim])
      local_kv_shape = mtf.Shape(self.batch_dims +
                                 [self.heads_dim,
                                  local_attention_window, self.kv_dim])
    else:
      beam_dim = mtf.Dimension("beam", hparams.beam_size)
      ids_shape = mtf.Shape(self.batch_dims + [beam_dim, self.length_dim])
      kv_shape = mtf.Shape(self.batch_dims +
                           [beam_dim, self.heads_dim,
                            self.memory_length_dim, self.kv_dim])
      local_kv_shape = mtf.Shape(self.batch_dims +
                                 [beam_dim, self.heads_dim,
                                  local_attention_window, self.kv_dim])

    initial_ids = mtf.constant(mesh, 0, ids_shape, dtype=tf.int32)
    initial_states = []
    for layer in hparams.decoder_layers:
      if layer == "att":
        initial_states.extend(
            [mtf.zeros(mesh, kv_shape, dtype=self.activation_dtype)] * 2)
      elif layer == "local_att":
        initial_states.extend(
            [mtf.zeros(mesh, local_kv_shape, dtype=self.activation_dtype)] * 2)

    def logits_fn(step_num, ids, states):
      """Produce logits for this step, and new states."""
      ids_this_step = mtf.gather(ids, step_num - 1, self.length_dim)
      x = (mtf.gather(targets_embedding_var, ids_this_step,
                      self.targets_vocab_dim) +
           mtf.gather(positional_embedding_var, step_num, self.max_length_dim))
      with tf.variable_scope("decoder"):
        x, new_states = self._layer_stack(
            x,
            hparams.decoder_layers,
            encdec_attention_mask=encoder_attention_mask,
            step_num=step_num,
            encdec_tensors=encdec_tensors,
            states=states)
      logits = mtf.matmul(x, softmax_var)
      return logits, new_states

    if hparams.beam_size == 1:
      temperature = (0.0 if hparams.sampling_method == "argmax"
                     else hparams.sampling_temp)
      return mtf.beam_search.greedy_decode(
          logits_fn,
          initial_ids,
          temperature=temperature,
          initial_states=initial_states,
          forced_ids=partial_targets,
          use_tpu=hparams.use_tpu)
    else:
      if hparams.transformer_type == "encdec":
        input_length = mtf.reduce_sum(
            mtf.to_float(mtf.cast(inputs, tf.bool)),
            reduced_dim=self.length_dim)
        max_input_length = mtf.reduce_max(input_length)
        decode_length = mtf.cast(
            max_input_length * hparams.decode_length_multiplier
            + hparams.decode_length_constant, tf.int32)
      else:
        decode_length = None
      beams, unused_scores = mtf.beam_search.beam_search(
          logits_fn,
          initial_ids,
          hparams.alpha,
          states=initial_states,
          decode_length=decode_length,
          use_tpu=hparams.use_tpu,
          dtype=self.activation_dtype)
      return mtf.gather(beams, mtf.constant(mesh, 0, dtype=tf.int32), beam_dim)
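The logits_fn above follows the step-wise contract used by mtf.beam_search: given the step number, the ids produced so far and the incremental states, it returns next-token logits plus updated states. A conceptual pure-Python sketch of how a greedy (temperature 0) decoder consumes such a function; this is a stand-in for illustration, not the mtf implementation.

import numpy as np

def greedy_decode_sketch(logits_fn, initial_ids, decode_length, initial_states):
    # Repeatedly call logits_fn and append the argmax token.
    ids = list(initial_ids)
    states = initial_states
    for step_num in range(1, decode_length + 1):
        logits, states = logits_fn(step_num, ids, states)
        ids.append(int(np.argmax(logits)))
    return ids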
Exemplo n.º 26
0
def get_optimizer(mesh, loss, params, variable_dtype, inp_var_grads=None):
    """Creates and returns an optimizer training op."""
    global_step = tf.train.get_or_create_global_step()

    learning_rate = tf.constant(value=params["lr"],
                                shape=[],
                                dtype=variable_dtype.slice_dtype)
    clip_value = mtf.constant(mesh,
                              params["gradient_clipping"],
                              dtype=variable_dtype.slice_dtype)

    if inp_var_grads is None:
        var_grads = mtf.gradients(
            [loss], [v.outputs[0] for v in mesh.graph.trainable_variables])
    else:
        var_grads = inp_var_grads

    # Cast to full precision
    var_grads_fp = [mtf.cast(v, variable_dtype.slice_dtype) for v in var_grads]

    # Decay the LR to its final value (lr * 0.1) by this step; defaults to train_steps.
    end_step = params.get("lr_decay_end", params["train_steps"])

    if params["lr_decay"] == "linear":
        learning_rate = tf.train.polynomial_decay(
            learning_rate,
            global_step,
            end_step,
            end_learning_rate=params["lr"] *
            0.1,  # Decrease to 10% of initial LR according to GPT-3 paper
            power=1.0,
            cycle=False)
    elif params["lr_decay"] == "cosine":
        learning_rate = tf.train.cosine_decay(
            learning_rate,
            global_step,
            end_step,
            alpha=0.1  # Alpha is min lr value as a fraction of init lr.
        )

    if params["warmup_steps"] > 0:
        global_steps_int = tf.cast(global_step, tf.int32)
        warmup_steps_int = tf.constant(params["warmup_steps"], dtype=tf.int32)

        dtype = variable_dtype.slice_dtype

        global_steps_float = tf.cast(global_steps_int, dtype)
        warmup_steps_float = tf.cast(warmup_steps_int, dtype)

        warmup_percent_done = global_steps_float / warmup_steps_float
        warmup_learning_rate = learning_rate * warmup_percent_done

        is_warmup = tf.cast(global_steps_int < warmup_steps_int, dtype)
        learning_rate = ((1.0 - is_warmup) * learning_rate +
                         is_warmup * warmup_learning_rate)

    learning_rate = mtf.import_fully_replicated(mesh,
                                                learning_rate,
                                                mtf.Shape([]),
                                                name="learning_rate")
    mtf.scalar_summary("lr", learning_rate)

    if params["opt_name"].lower() == "adam":
        optimizer = AdamWeightDecayOptimizer(
            learning_rate=learning_rate,
            weight_decay_rate=params["weight_decay"],
            beta_1=params["beta1"],
            beta_2=params["beta2"],
            epsilon=params["epsilon"],
            exclude_from_weight_decay=["norm", "bias"],
            variable_dtype=variable_dtype)
    else:
        optimizer = mtf.optimize.AdafactorOptimizer(
            learning_rate=params["lr"],
            decay_rate=params["weight_decay"],
            beta1=params["beta1"],
            epsilon1=params["ada_epsilon1"],
            epsilon2=params["ada_epsilon2"])

    if params["use_tpu"]:
        optimizer = tf.tpu.CrossShardOptimizer(optimizer)

    if params["gradient_clipping"] is not None:
        (var_grads_fp, _) = clip_by_global_norm(var_grads_fp,
                                                clip_norm=clip_value)

    update_ops = optimizer.apply_grads(var_grads_fp,
                                       mesh.graph.trainable_variables)
    return learning_rate, update_ops, var_grads_fp
Exemplo n.º 27
0
def _switch_gating(inputs,
                   outer_expert_dims,
                   experts_dim,
                   expert_capacity_dim,
                   hparams,
                   train,
                   variable_dtype,
                   importance=None,
                   name="switch_gating",
                   num_microbatches=None):
  """Compute a switch top-1 gating with no-token-left behind behavior."""
  # SELECT EXPERT
  if train:
    policy = hparams.moe_rand_1_policy_train
  else:
    policy = hparams.moe_rand_1_policy_eval

  # Input perturbations
  if train and policy == "input_jitter":
    inputs = mtf.layers.multiplicative_jitter(inputs, hparams.moe_rand_1_jitter)

  gate_logits = mtf.layers.dense(
      inputs,
      experts_dim,
      use_bias=False,
      expert_dims=outer_expert_dims,
      variable_dtype=variable_dtype,
      name=name)
  raw_gates = mtf.softmax(gate_logits, reduced_dim=experts_dim)

  # The internals of this function run in float32.
  #   bfloat16 seems to reduce quality.
  raw_gates = mtf.to_float(raw_gates)

  # Top-k operation
  k_dim = mtf.Dimension("k", hparams.moe_switch_top_k)
  expert_gate, expert_index = mtf.top_k(
      raw_gates, reduced_dim=experts_dim, k_dim=k_dim)
  expert_mask = mtf.one_hot(expert_index, experts_dim)

  # LOAD BALANCING LOSS
  outer_batch_dim = inputs.shape[0]
  batch_dim = inputs.shape[1]
  group_size_dim = inputs.shape[-2]
  density_1 = mtf.reduce_mean(expert_mask, reduced_dim=group_size_dim)
  density_1_proxy = mtf.reduce_mean(raw_gates, reduced_dim=group_size_dim)
  if importance is not None:
    expert_mask *= mtf.cast(mtf.equal(importance, 1.0), dtype=raw_gates.dtype)
    expert_gate *= mtf.cast(mtf.equal(importance, 1.0), dtype=raw_gates.dtype)
    density_1_proxy *= mtf.cast(
        mtf.equal(importance, 1.0), dtype=raw_gates.dtype)
  loss = (
      mtf.reduce_mean(density_1_proxy * density_1) *
      float(experts_dim.size * experts_dim.size))
  if num_microbatches and num_microbatches > 1:
    tf.logging.info("Dividing load-balance loss by num_microbatches={}".format(
        num_microbatches))
    loss /= num_microbatches

  # Logging
  if train:
    entropy = mtf.reduce_sum(
        -raw_gates * mtf.log(raw_gates + 1e-9), reduced_dim=experts_dim)
    batch_entropy = mtf.reduce_mean(entropy)
    mtf.scalar_summary(name + "/entropy", batch_entropy)

    mask_count_experts = mtf.reduce_sum(expert_mask, output_shape=[experts_dim])
    total_routed = mtf.reduce_sum(mask_count_experts)
    expert_fraction = mtf.to_float(mask_count_experts / total_routed)
    split_fractions = mtf.split(
        expert_fraction,
        split_dim=experts_dim,
        num_or_size_splits=experts_dim.size)
    for fraction in split_fractions:
      mtf.scalar_summary("experts/" + fraction.name.replace(":", "/"),
                         mtf.reduce_mean(fraction))
    mtf.scalar_summary("aux_loss", mtf.reduce_mean(loss))

  # COMPUTE ASSIGNMENT TO EXPERT
  # Iteratively route tokens (no-token-left-behind). The idea is to route as
  # many tokens as possible to top-i before then trying top-(i+1).
  top_k_masks = mtf.split(
      expert_mask, split_dim=k_dim, num_or_size_splits=k_dim.size)
  top_k_gates = mtf.split(
      expert_gate, split_dim=k_dim, num_or_size_splits=k_dim.size)
  top_k_indices = mtf.split(
      expert_index, split_dim=k_dim, num_or_size_splits=k_dim.size)

  # Tensors that accumulate values over the iterative routing process.
  combine_tensor = mtf.constant(
      inputs.mesh,
      value=0,
      shape=[outer_batch_dim, batch_dim, experts_dim, expert_capacity_dim])
  cum_tokens = mtf.constant(
      inputs.mesh, value=0, shape=[outer_batch_dim, batch_dim, experts_dim])
  tokens_left_to_route = mtf.constant(
      inputs.mesh, value=1., shape=[outer_batch_dim, batch_dim, group_size_dim])

  expert_capacity_float = float(expert_capacity_dim.size)
  for (top_i_mask, top_i_gate, top_i_index) in zip(top_k_masks, top_k_gates,
                                                   top_k_indices):
    top_i_mask = mtf.reshape(
        top_i_mask,
        new_shape=[outer_batch_dim, batch_dim, group_size_dim, experts_dim])
    # Operate only on the unrouted tokens.
    top_i_mask *= tokens_left_to_route

    # Record cumulative number of tokens to each expert across iterations.
    cumulative_tokens_in_expert = cum_tokens + mtf.cumsum(
        top_i_mask, group_size_dim)

    # 1.0 for tokens that still fit within the expert's capacity, 0.0 otherwise.
    within_capacity = mtf.to_float(
        mtf.less_equal(cumulative_tokens_in_expert, expert_capacity_float))
    output_i_tokens = top_i_mask * within_capacity

    # Update the cumulative tokens routed to each expert.
    cum_tokens += mtf.reduce_sum(output_i_tokens, reduced_dim=group_size_dim)
    tokens_left_to_route -= (
        mtf.reduce_sum(output_i_tokens, reduced_dim=experts_dim))

    # Combine-tensor for this iteration
    output_i_tokens_flat = mtf.reduce_sum(
        output_i_tokens, reduced_dim=experts_dim)
    position_in_expert = cumulative_tokens_in_expert - 1
    top_i_combine_tensor = (
        top_i_gate * output_i_tokens_flat *
        mtf.one_hot(top_i_index, experts_dim) *
        mtf.one_hot(mtf.to_int32(position_in_expert), expert_capacity_dim))
    combine_tensor += top_i_combine_tensor

  # Match the inputs dtype.
  combine_tensor = mtf.cast(combine_tensor, inputs.dtype)
  loss = mtf.cast(loss, inputs.dtype)
  dispatch_tensor = mtf.cast(
      mtf.cast(combine_tensor, tf.bool), combine_tensor.dtype)

  return dispatch_tensor, combine_tensor, loss
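A rough NumPy sketch of the two ideas the comments above refer to: top-1 assignment with a per-expert capacity, and the load-balancing auxiliary loss built from the fraction of tokens routed to each expert (density_1) and the mean gate probability per expert (density_1_proxy). The names and the flat [group_size, num_experts] layout are simplifications for illustration.

import numpy as np

def switch_top1_sketch(gate_probs, expert_capacity):
    # gate_probs: [group_size, num_experts] softmax outputs of the router.
    num_experts = gate_probs.shape[-1]
    expert_index = np.argmax(gate_probs, axis=-1)       # top-1 expert per token
    expert_mask = np.eye(num_experts)[expert_index]     # one-hot [group, experts]

    # Load-balancing auxiliary loss, as in the code above: scaled by
    # num_experts**2 so a perfectly uniform router gives a loss of ~1.
    density_1 = expert_mask.mean(axis=0)
    density_1_proxy = gate_probs.mean(axis=0)
    aux_loss = float(np.mean(density_1 * density_1_proxy) * num_experts ** 2)

    # Capacity: a token's position within its chosen expert; tokens whose
    # cumulative position exceeds expert_capacity overflow and are dropped.
    position_in_expert = np.cumsum(expert_mask, axis=0) * expert_mask
    keep = (position_in_expert > 0) & (position_in_expert <= expert_capacity)
    return keep.astype(np.float32), aux_loss

The function above extends this single pass with an iterative top-k loop, so that tokens dropped by their first-choice expert can still be routed to a lower-ranked choice ("no-token-left-behind").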
Exemplo n.º 28
0
  def beam_search(self,
                  inputs,
                  decode_length,
                  variable_dtype=mtf.VariableDType(tf.float32),
                  encoder_output=None,
                  encoder_sequence_id=None,
                  alpha=0.6,
                  shared_params=None,
                  encoder_layer_outputs=None):
    """Beam search.

    Args:
      inputs: an int32 zero-Tensor with shape [<batch_dims>, beam_dim,
        length_dim].
      decode_length: an int32 mtf scalar.  Maximum decode length.
      variable_dtype: a mtf.VariableDType
      encoder_output: an optional Tensor
      encoder_sequence_id: an optional Tensor
      alpha: a floating point value (length bonus)
      shared_params: an optional dictionary
      encoder_layer_outputs: optional read-only list of tensor activations used
        when decoding, one for each input layer plus the embedding layer

    Returns:
      a Tensor with shape [<batch_dims>, beam_dim, length_dim]
    """
    if not self.autoregressive:
      raise ValueError("must be autoregressive")

    batch_dims = inputs.shape.dims[:-2]
    if len(batch_dims) != 1:
      raise NotImplementedError(
          "beam search supports exactly one batch dimension.")
    beam_dim = inputs.shape.dims[-2]
    length_dim = inputs.shape.dims[-1]
    initial_position = mtf.reduce_sum(
        mtf.to_int32(mtf.not_equal(inputs, 0)), reduced_dim=length_dim)
    sequence_id = 1 if encoder_sequence_id is not None else None

    context_first_part = Context(
        mesh=inputs.mesh,
        batch_dims=batch_dims + [beam_dim],
        length_dim=length_dim,
        model_dim=self.model_dim,
        variable_dtype=variable_dtype,
        mode="first_part",
        autoregressive=self.autoregressive,
        new_states=[],
        initial_position=initial_position,
        sequence_id=sequence_id,
        encoder_output=encoder_output,
        encoder_sequence_id=encoder_sequence_id,
        constant_states=[],
        shared_params=shared_params,
        layout=self.layout,
        mesh_shape=self.mesh_shape,
        encoder_layer_outputs=encoder_layer_outputs)

    shifted_inputs = mtf.shift(inputs, offset=1, dim=length_dim, wrap=False)
    with tf.variable_scope(self.name):
      logits = self._call_internal(context_first_part, shifted_inputs)
    del logits
    # There are no partial targets.
    # Replace initial states by zeros to avoid computing them.
    initial_states = [mtf.zeros_like(t) for t in context_first_part.new_states]
    constant_states = context_first_part.constant_states

    def logits_fn(step_num, ids, states):
      """logits_fn for mtf.beam_search.beam_search()."""
      context_incremental = Context(
          mesh=inputs.mesh,
          batch_dims=batch_dims + [beam_dim],
          length_dim=length_dim,
          model_dim=self.model_dim,
          variable_dtype=variable_dtype,
          mode="incremental",
          autoregressive=self.autoregressive,
          position=step_num,
          states=states,
          new_states=[],
          sequence_id=sequence_id,
          encoder_output=encoder_output,
          encoder_sequence_id=encoder_sequence_id,
          constant_states=constant_states,
          shared_params=shared_params,
          layout=self.layout,
          mesh_shape=self.mesh_shape,
          encoder_layer_outputs=encoder_layer_outputs)
      inputs_this_step = mtf.gather(ids, step_num - 1, length_dim)
      with tf.variable_scope(self.name, reuse=True):
        logits = self._call_internal(context_incremental, inputs_this_step)
      return mtf.to_float(logits), context_incremental.new_states

    beams, unused_scores = mtf.beam_search.beam_search(
        logits_fn,
        inputs,
        alpha,
        states=initial_states,
        decode_length=decode_length,
        use_tpu=True,
        dtype=tf.float32,
        mesh_shape=self.mesh_shape,
        layout=self.layout)
    return mtf.gather(
        beams, mtf.constant(inputs.mesh, 0, dtype=tf.int32), beam_dim)
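The alpha argument is the usual length-bonus exponent: beam candidates are ranked by their summed log-probability divided by a GNMT-style length penalty, as in the tensor2tensor-style beam search this code builds on. A small sketch of that scoring rule, for orientation only:

def length_penalized_score_sketch(log_probs_sum, length, alpha):
    # Larger is better; alpha=0 disables length normalization.
    length_penalty = ((5.0 + length) / 6.0) ** alpha
    return log_probs_sum / length_penalty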
Exemplo n.º 29
0
def get_optimizer(loss, params, summary, inp_var_grads=None):
    """Creates and returns an optimizer training op."""

    global_step = tf.train.get_or_create_global_step()  # get global step
    mesh = loss.mesh  # get mesh info from loss
    graph = mesh.graph  # get graph info from mesh

    if inp_var_grads is None:
        var_grads = mtf.gradients(
            [loss], [v.outputs[0] for v in graph.trainable_variables])
    else:
        var_grads = inp_var_grads

    learning_rate = tf.constant(value=params["lr"], shape=[],
                                dtype=tf.float32)  # grab lr param

    if params["lr_decay"] == "linear":
        learning_rate = tf.train.polynomial_decay(
            learning_rate,
            global_step,
            params["train_steps"],
            end_learning_rate=params["lr"] *
            0.1,  # decrease to 10% of initial LR according to GPT-3 paper
            power=1.0,
            cycle=False,
        )
    elif params["lr_decay"] == "cosine":
        learning_rate = tf.train.cosine_decay(
            learning_rate,
            global_step,
            params["train_steps"],
            alpha=0.1,  # alpha is min lr value as a fraction of init lr.
        )

    if params["warmup_steps"] > 0:
        global_steps_int = tf.cast(global_step, tf.int32)
        warmup_steps_int = tf.constant(params["warmup_steps"], dtype=tf.int32)

        global_steps_float = tf.cast(global_steps_int, tf.float32)
        warmup_steps_float = tf.cast(warmup_steps_int, tf.float32)

        warmup_percent_done = global_steps_float / warmup_steps_float
        warmup_learning_rate = learning_rate * warmup_percent_done

        is_warmup = tf.cast(global_steps_int < warmup_steps_int, tf.float32)
        learning_rate = ((1.0 - is_warmup) * learning_rate +
                         is_warmup * warmup_learning_rate)

    summary.scalar("lr", learning_rate)

    if params["opt_name"].lower() == "adam":
        optimizer = mtf.optimize.AdamWeightDecayOptimizer(
            learning_rate=learning_rate,
            weight_decay_rate=params["weight_decay"],
            beta_1=params["beta1"],
            beta_2=params["beta2"],
            epsilon=params["epsilon"],
            exclude_from_weight_decay=["norm", "bias"],
        )
    else:
        optimizer = mtf.optimize.AdafactorOptimizer(
            learning_rate=params["lr"],
            decay_rate=params["weight_decay"],
            beta1=params["beta1"],
            epsilon1=params["ada_epsilon1"],
            epsilon2=params["ada_epsilon2"],
        )

    if params["gradient_clipping"] is not None:
        clip_value = mtf.constant(mesh,
                                  params["gradient_clipping"],
                                  dtype=tf.float32)
        (var_grads, _) = clip_by_global_norm(var_grads, clip_norm=clip_value)

    update_ops = optimizer.apply_grads(var_grads, graph.trainable_variables)
    return learning_rate, update_ops
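Several of the optimizer examples above clip gradients with clip_by_global_norm before apply_grads. As a reminder of what that does, a minimal NumPy sketch (illustrative only; the helper used above operates on mtf Tensors):

import numpy as np

def clip_by_global_norm_sketch(grads, clip_norm):
    # Scale every gradient by clip_norm / max(global_norm, clip_norm).
    global_norm = np.sqrt(sum(np.sum(np.square(g)) for g in grads))
    scale = clip_norm / max(global_norm, clip_norm)
    return [g * scale for g in grads], global_norm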