Example #1
    def _sample(self, features, mesh):
        """Autoregressively sample output ids, via greedy or beam-search decoding."""
        hparams = self._hparams
        (inputs_embedding_var, targets_embedding_var, softmax_var,
         positional_embedding_var) = self._embedding_and_softmax_vars(mesh)
        if self.has_input:
            inputs = features["inputs"]
            while len(inputs.shape.as_list()) > 2:
                inputs = tf.squeeze(inputs, axis=2)
            # Pad to the fixed [batch_size, max_length] shape expected on the mesh.
            actual_batch_size = tf.shape(inputs)[0]
            actual_length = tf.shape(inputs)[1]
            inputs = tf.pad(inputs,
                            [[0, hparams.batch_size - actual_batch_size],
                             [0, hparams.max_length - actual_length]])
            inputs = self._import_to_batch_by_length(inputs, "inputs", mesh,
                                                     hparams)
            x = (mtf.gather(inputs_embedding_var, inputs,
                            self.inputs_vocab_dim) +
                 mtf.reshape(positional_embedding_var,
                             mtf.Shape([self.length_dim, self.model_dim])))
            encoder_attention_mask = (mtf_layers.attention_mask_ignore_padding(
                inputs, dtype=self.activation_dtype))
            with tf.variable_scope("encoder"):
                x = self._layer_stack(
                    x,
                    hparams.num_encoder_layers,
                    self_attention_mask=encoder_attention_mask)
            encoder_output = mtf.rename_dimension(x, self.length_dim.name,
                                                  self.memory_length_dim.name)
            # Precompute per-layer encoder-decoder attention keys and values,
            # since the encoder output is fixed for the whole decode.
            encdec_tensors = []
            for layer_num in range(hparams.num_decoder_layers):
                with tf.variable_scope("decoder/layer_%d/encdec_attention" %
                                       layer_num):
                    q_var, k_var, v_var, o_var = mtf_layers.multihead_attention_vars(
                        mesh, self.heads_dim, self.model_dim, self.kv_dim,
                        self.activation_dtype)
                    k = mtf.einsum([encoder_output, k_var],
                                   mtf.Shape([
                                       self.batch_dim, self.heads_dim,
                                       self.memory_length_dim, self.kv_dim
                                   ]))
                    v = mtf.einsum([encoder_output, v_var],
                                   mtf.Shape([
                                       self.batch_dim, self.heads_dim,
                                       self.memory_length_dim, self.kv_dim
                                   ]))
                encdec_tensors.append((q_var, o_var, k, v))
            partial_targets = None
        else:
            encdec_tensors = None
            encoder_output = None
            encoder_attention_mask = None
            # Prepare partial targets from either features["inputs"] or
            # features["targets"]; the decoded outputs are forced to begin
            # with these sequences.
            partial_targets = features.get("inputs", None)
            if partial_targets is None:
                partial_targets = features.get("targets", None)
            if partial_targets is not None:
                partial_targets = common_layers.expand_squeeze_to_nd(
                    partial_targets, 2)
                partial_targets = tf.to_int32(partial_targets)
                partial_targets_batch = tf.shape(partial_targets)[0]
                partial_targets_length = tf.shape(partial_targets)[1]
                partial_targets = tf.pad(
                    partial_targets,
                    [[0, hparams.batch_size - partial_targets_batch],
                     [0, hparams.max_length - partial_targets_length]])
                partial_targets = self._import_to_batch_by_length(
                    partial_targets, "partial_targets", mesh, hparams)

        # Beam search adds an extra "beam" dimension to the decoded ids and to
        # the cached attention keys/values.
        if hparams.beam_size == 1:
            ids_shape = mtf.Shape([self.batch_dim, self.length_dim])
            kv_shape = mtf.Shape([
                self.batch_dim, self.heads_dim, self.memory_length_dim,
                self.kv_dim
            ])
        else:
            beam_dim = mtf.Dimension("beam", hparams.beam_size)
            ids_shape = mtf.Shape([self.batch_dim, beam_dim, self.length_dim])
            kv_shape = mtf.Shape([
                self.batch_dim, beam_dim, self.heads_dim,
                self.memory_length_dim, self.kv_dim
            ])

        initial_ids = mtf.constant(mesh, 0, ids_shape, dtype=tf.int32)
        # One cached key tensor and one cached value tensor per decoder layer.
        initial_kv_states = (
            [mtf.zeros(mesh, kv_shape, dtype=self.activation_dtype)] *
            (2 * hparams.num_decoder_layers))

        def logits_fn(step_num, ids, states):
            """Produce logits for this step, and new states."""
            # The first half of the state list holds per-layer keys, the
            # second half holds the corresponding values.
            self_attention_k = states[:hparams.num_decoder_layers]
            self_attention_v = states[hparams.num_decoder_layers:]
            ids_this_step = mtf.gather(ids, step_num - 1, self.length_dim)
            x = (mtf.gather(targets_embedding_var, ids_this_step,
                            self.targets_vocab_dim) +
                 mtf.gather(positional_embedding_var, step_num,
                            self.max_length_dim))
            with tf.variable_scope("decoder"):
                x, new_self_attention_k, new_self_attention_v = (
                    self._decoder_layer_stack_incremental(
                        x,
                        step_num,
                        encdec_tensors,
                        self_attention_k,
                        self_attention_v,
                        encdec_attention_mask=encoder_attention_mask))
            logits = mtf.matmul(x, softmax_var)
            return logits, new_self_attention_k + new_self_attention_v

        if hparams.beam_size == 1:
            temperature = (0.0 if hparams.sampling_method == "argmax" else
                           hparams.sampling_temp)
            return mtf_beam_search.greedy_decode(
                logits_fn,
                initial_ids,
                temperature=temperature,
                initial_states=initial_kv_states,
                forced_ids=partial_targets,
                use_tpu=hparams.use_tpu)
        else:
            if self.has_input:
                input_length = mtf.reduce_sum(
                    mtf.to_float(mtf.cast(inputs, tf.bool)),
                    reduced_dim=self.length_dim)
                max_input_length = mtf.reduce_max(input_length)
                decode_length = mtf.cast(
                    max_input_length * hparams.decode_length_multiplier +
                    hparams.decode_length_constant, tf.int32)
            else:
                decode_length = None
            beams, unused_scores = mtf_beam_search.beam_search(
                logits_fn,
                initial_ids,
                hparams.alpha,
                states=initial_kv_states,
                decode_length=decode_length,
                use_tpu=hparams.use_tpu)
            return mtf.gather(beams, mtf.constant(mesh, 0, dtype=tf.int32),
                              beam_dim)
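
The decoding loop in Example #1 threads a flat list of cached tensors through logits_fn: the first hparams.num_decoder_layers entries are the per-layer self-attention key caches, the remaining entries are the value caches, and each step returns the updated caches concatenated in the same order. Below is a minimal sketch of that state-layout convention using plain Python lists in place of mesh-TensorFlow tensors (num_layers and the step helper are illustrative, not part of the original code):

    num_layers = 2  # illustrative stand-in for hparams.num_decoder_layers

    # Flat state list: [k_0, ..., k_{n-1}, v_0, ..., v_{n-1}]
    states = [[] for _ in range(2 * num_layers)]

    def step(states, token):
        """Mimic logits_fn's state handling: split, update, re-concatenate."""
        keys = states[:num_layers]      # first half: per-layer key caches
        values = states[num_layers:]    # second half: per-layer value caches
        new_keys = [k + [("k", token)] for k in keys]
        new_values = [v + [("v", token)] for v in values]
        return new_keys + new_values    # same flat layout expected next step

    for token in ["a", "b", "c"]:
        states = step(states, token)

    assert len(states) == 2 * num_layers
    assert all(len(cache) == 3 for cache in states)
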
Example #2
    def _mtf_model_fn(self, features, mesh):
        """Build the training graph: embed, encode, decode, and return (logits, loss)."""
        features = copy.copy(features)
        hparams = self._hparams
        targets = tf.to_int32(features["targets"])
        if len(targets.get_shape()) > 2:
            tf.logging.info("targets = %s" % targets)
            targets = tf.squeeze(targets, [2, 3])
        # Pad (and reshape) a feature to the fixed [batch_size, max_length] shape.
        def pad_to_max_length(x):
            extra_length = hparams.max_length - tf.shape(x)[1]
            x = tf.pad(x, [[0, 0], [0, extra_length]])
            x = tf.reshape(x, [hparams.batch_size, hparams.max_length])
            return x

        targets = pad_to_max_length(targets)
        for key in [
                "targets_segmentation", "targets_position",
                "inputs_segmentation", "inputs_position"
        ]:
            if key in features:
                features[key] = pad_to_max_length(features[key])
        shifted_targets = common_layers.shift_right_2d(targets)

        targets = self._import_to_batch_by_length(targets, "targets", mesh,
                                                  hparams)
        shifted_targets = self._import_to_batch_by_length(
            shifted_targets, "shifted_targets", mesh, hparams)

        if "targets_segmentation" in features:
            # "Packed" dataset - keep the examples from seeing each other.
            targets_segmentation = self._import_to_batch_by_length(
                features["targets_segmentation"], "targets_segmentation", mesh,
                hparams)
            targets_position = self._import_to_batch_by_length(
                features["targets_position"], "targets_position", mesh,
                hparams)
            decoder_self_attention_mask = (
                mtf_layers.attention_mask_autoregressive(
                    targets_position, dtype=self.activation_dtype) +
                mtf_layers.attention_mask_same_segment(
                    targets_segmentation, dtype=self.activation_dtype))
        else:
            targets_position = mtf.range(mesh, self.length_dim, dtype=tf.int32)
            decoder_self_attention_mask = mtf_layers.attention_mask_autoregressive(
                targets_position, dtype=self.activation_dtype)

        def layer_prepostprocess_dropout(x):
            return mtf.dropout(
                x,
                keep_prob=1.0 - hparams.layer_prepostprocess_dropout,
                noise_shape=mtf.Shape([self.batch_dim, self.model_dim]))

        extra_losses = []
        (inputs_embedding_var, targets_embedding_var, softmax_var,
         positional_embedding_var) = self._embedding_and_softmax_vars(mesh)
        if self.has_input:
            inputs = tf.squeeze(tf.to_int32(features["inputs"]), [2, 3])
            inputs = pad_to_max_length(inputs)
            inputs = self._import_to_batch_by_length(inputs, "inputs", mesh,
                                                     hparams)
            if "inputs_segmentation" in features:
                # "Packed" dataset - keep the examples from seeing each other.
                inputs_segmentation = self._import_to_batch_by_length(
                    features["inputs_segmentation"], "inputs_segmentation",
                    mesh, hparams)
                inputs_position = self._import_to_batch_by_length(
                    features["inputs_position"], "inputs_position", mesh,
                    hparams)
                encoder_self_attention_mask = (
                    mtf_layers.attention_mask_same_segment(
                        inputs_segmentation, dtype=self.activation_dtype))
                encoder_decoder_attention_mask = (
                    mtf_layers.attention_mask_same_segment(
                        targets_segmentation,
                        inputs_segmentation,
                        dtype=self.activation_dtype))
            else:
                inputs_position = mtf.range(mesh,
                                            self.length_dim,
                                            dtype=tf.int32)
                encoder_self_attention_mask = (
                    mtf_layers.attention_mask_ignore_padding(
                        inputs, dtype=self.activation_dtype))
                encoder_decoder_attention_mask = encoder_self_attention_mask

            x = (mtf.gather(inputs_embedding_var, inputs,
                            self.inputs_vocab_dim) +
                 mtf.gather(positional_embedding_var, inputs_position,
                            self.max_length_dim))
            x = layer_prepostprocess_dropout(x)
            with tf.variable_scope("encoder"):
                x = self._layer_stack(
                    x,
                    hparams.num_encoder_layers,
                    self_attention_mask=encoder_self_attention_mask,
                    losses=extra_losses)
            encoder_output = mtf.rename_dimension(x, self.length_dim.name,
                                                  self.memory_length_dim.name)
        else:
            encoder_output = None
            encoder_decoder_attention_mask = None

        # Decoder: embed the shifted targets and add positional embeddings.
        x = (mtf.gather(targets_embedding_var, shifted_targets,
                        self.targets_vocab_dim) +
             mtf.gather(positional_embedding_var, targets_position,
                        self.max_length_dim))
        x = layer_prepostprocess_dropout(x)

        with tf.variable_scope("decoder"):
            x = self._layer_stack(
                x,
                hparams.num_decoder_layers,
                encoder_output=encoder_output,
                self_attention_mask=decoder_self_attention_mask,
                encdec_attention_mask=encoder_decoder_attention_mask,
                losses=extra_losses)
        logits = mtf.matmul(x, softmax_var)
        off_value = hparams.label_smoothing / self._targets_vocab_size
        on_value = 1.0 - hparams.label_smoothing + off_value
        soft_targets = mtf.one_hot(targets,
                                   self.targets_vocab_dim,
                                   on_value=on_value,
                                   off_value=off_value,
                                   dtype=self.activation_dtype)
        loss = mtf_layers.softmax_cross_entropy_with_logits(
            logits, soft_targets, self.targets_vocab_dim)
        weights = mtf_layers.weights_nonzero(targets,
                                             dtype=self.activation_dtype)
        loss = mtf.reduce_mean(loss * weights)
        for l in extra_losses:
            loss += l
        return logits, loss
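
The loss in Example #2 uses label smoothing: every class receives off_value = label_smoothing / vocab_size of probability mass and the true class receives on_value = 1.0 - label_smoothing + off_value, so each smoothed target row still sums to 1. A small numeric sketch of that arithmetic (the label_smoothing and vocab_size values below are made up for illustration):

    label_smoothing = 0.1  # illustrative stand-in for hparams.label_smoothing
    vocab_size = 4         # illustrative stand-in for self._targets_vocab_size

    off_value = label_smoothing / vocab_size      # 0.025 on each other class
    on_value = 1.0 - label_smoothing + off_value  # 0.925 on the true class

    # The smoothed one-hot row for a single target still sums to 1.
    row = [on_value] + [off_value] * (vocab_size - 1)
    assert abs(sum(row) - 1.0) < 1e-9
    print(row)  # [0.925, 0.025, 0.025, 0.025]
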