예제 #1
0
 def call(self, inputs, **kwargs):
     """Resize ``source`` to the spatial size of ``target`` (nearest neighbour).

     Args:
         inputs: pair ``(source, target)`` of 4-D image tensors in the layout
             reported by ``keras.backend.image_data_format()``.
         **kwargs: ignored.

     Returns:
         ``source`` resized to ``target``'s spatial size, in the same layout
         as the input. For ``channels_first`` data the tensor is moved to
         channels-last for the resize and moved back afterwards.

     NOTE(review): ``backend`` is a project module distinct from
     ``keras.backend``; its ``resize_images`` is assumed to operate on
     channels-last input — confirm.
     """
     source, target = inputs
     target_shape = keras.backend.shape(target)
     if keras.backend.image_data_format() != 'channels_first':
         # channels_last: spatial dims are axes 1 and 2.
         return backend.resize_images(source, (target_shape[1], target_shape[2]), method='nearest')

     # channels_first: NCHW -> NHWC, resize to target's axes 2/3, then back.
     channels_last = backend.transpose(source, (0, 2, 3, 1))
     resized = backend.resize_images(channels_last, (target_shape[2], target_shape[3]), method='nearest')
     return backend.transpose(resized, (0, 3, 1, 2))
예제 #2
0
def get_attn_mask_func(inputs):
    """Build the XLNet-style attention mask from a padding/input mask.

    ``inputs`` is read as a 2-D tensor: axis 0 is used as ``batch_size`` and
    axis 1 as ``target_len`` (presumably ``(batch, seq_len)`` — TODO confirm
    against the caller). The returned float32 tensor is
    ``(seq_len, seq_len, batch, 1)`` by broadcasting and holds only 0.0/1.0:
    an entry is 1.0 where the transposed input mask is > 0 for the key
    position, except on the diagonal. The identity is subtracted there,
    so the diagonal is always 0.0.

    The commented TensorFlow lines mirror the original XLNet implementation
    with memory length ``mlen = 0``; the zero-width tensors below keep that
    structure but contribute no values.
    """
    input_mask = inputs
    batch_size = K.shape(input_mask)[0]
    target_len = K.shape(input_mask)[1]
    # [seq_len, batch] (e.g. 512, ?)
    input_mask_trans = K.transpose(input_mask)

    # [1, seq_len, batch]
    # data_mask = input_mask[None]
    data_mask = K.expand_dims(input_mask_trans, axis=0)

    # [1, 0, batch]: empty memory segment (mlen = 0).
    # mems_mask = tf.zeros([tf.shape(data_mask)[0], mlen, bsz], dtype=tf_float)
    # NOTE(review): Lambda is applied to a raw shape scalar, not a Keras
    # tensor; this relies on graph-mode Keras accepting that — confirm.
    mems_mask = keras.layers.Lambda(
        lambda x: tf.zeros([1, 0, x], dtype=tf.float32))(batch_size)

    # [1, seq_len, batch] (concatenating the empty memory is a no-op)
    # data_mask = tf.concat([mems_mask, data_mask], 1)
    data_mask = K.concatenate([mems_mask, data_mask], axis=1)

    # [1, seq_len, batch, 1]
    # attn_mask = data_mask[:, :, :, None]
    attn_mask = K.expand_dims(data_mask, axis=-1)

    # Binarise: 1.0 where the mask value is > 0, else 0.0.
    # attn_mask = tf.cast(attn_mask > 0, dtype=tf_float)
    attn_mask = keras.layers.Lambda(
        lambda x: tf.cast(x > 0, dtype=tf.float32))(attn_mask)

    # [seq_len, seq_len] negative identity (e.g. -eye(512)).
    # non_tgt_mask = -tf.eye(512, dtype=tf.float32)
    non_tgt_mask = keras.layers.Lambda(lambda x: -tf.eye(x, dtype=tf.float32))(
        target_len)

    # Prepend an empty [seq_len, 0] memory block (mlen = 0, so a no-op).
    # non_tgt_mask = tf.concat([tf.zeros([qlen, mlen], dtype=tf_float),non_tgt_mask], axis=-1)
    tmp = keras.layers.Lambda(lambda x: tf.zeros([x, 0], dtype=tf.float32))(
        target_len)
    non_tgt_mask = K.concatenate([tmp, non_tgt_mask], axis=-1)

    # [seq_len, seq_len, 1, 1] broadcast against [1, seq_len, batch, 1]
    # -> [seq_len, seq_len, batch, 1]. Because attn_mask is 0/1 and the
    # diagonal of non_tgt_mask is -1, diagonal entries are never > 0.
    # non_tgt_mask = tf.cast((attn_mask + non_tgt_mask[:, :, None, None]) > 0, dtype=tf_float)
    non_tgt_mask = K.expand_dims(non_tgt_mask, axis=-1)
    non_tgt_mask = K.expand_dims(non_tgt_mask, axis=-1)
    tmp2 = K.greater(attn_mask + non_tgt_mask, 0)
    attn_mask = K.cast(tmp2, tf.float32)
    return attn_mask
예제 #3
0
def build_xlnet_for_tf_estimator(inputs,
                                 num_token,
                                 num_layer,
                                 num_head,
                                 embedding_dim,
                                 attention_head_dim,
                                 feed_forward_dim,
                                 target_len,
                                 is_training,
                                 memory_len=None,
                                 dropout=0.0,
                                 attention_dropout=0.0,
                                 attention_type=None,
                                 shared_biases=True):
    """Build the single-device XLNet SQuAD graph and return the loss tensor.

    Args:
        inputs: sequence of eight tensors, unpacked as ``(input_ids,
            input_mask, segment_ids, cls_index, p_mask, start_positions,
            end_positions, is_impossible)``.
        num_token: vocabulary size of the ``Embed-Token`` embedding.
        num_layer: number of transformer layers to stack.
        num_head: attention heads per layer.
        embedding_dim: model (hidden) width.
        attention_head_dim: width of each attention head.
        feed_forward_dim: inner width of each ``FeedForward`` block.
        target_len: passed to ``XLnetLoss`` as ``seq_len``.
        is_training: passed as ``training`` to the dropout and feed-forward
            layers and as ``is_training`` to the attention layers.
        memory_len, attention_type, shared_biases: accepted for interface
            compatibility; not used by this function.
        dropout: dropout rate for embeddings, attention and feed-forward.
        attention_dropout: passed to the attention layers as ``dropatt``.

    Returns:
        The output of the ``XLNET_LOSS`` layer.
    """
    input_ids, input_mask, segment_ids, cls_index, \
    p_mask, start_positions, end_positions, is_impossible = inputs

    attn_mask = get_attn_mask(input_mask)

    input_ids_trans = keras.layers.Lambda(lambda x: K.transpose(x))(input_ids)
    token_embed = keras.layers.Embedding(input_dim=num_token,
                                         output_dim=embedding_dim,
                                         name='Embed-Token')(input_ids_trans)
    token_embed_dropout = keras.layers.Dropout(rate=dropout,
                                               name='Embed-Token-Dropout')(
                                                   token_embed,
                                                   training=is_training)

    pos_emb = get_pos_emb([input_ids_trans, token_embed])
    pos_emb = keras.layers.Dropout(rate=dropout)(pos_emb, training=is_training)

    initializer = keras.initializers.get('normal')
    initializer.stddev = 0.02

    segment_ids_trans = keras.layers.Lambda(lambda x: K.transpose(x))(
        segment_ids)
    segment_mat, segment_embed = RelativeSegmentEmbedding(
        num_layer=num_layer,
        num_head=num_head,
        attention_dim=attention_head_dim,
        initializer=initializer,
        name='Embed-Segment',
    )(segment_ids_trans)

    r_w_bias, r_r_bias, r_s_bias = RelativeBias(
        num_layer=num_layer,
        num_head=num_head,
        attention_head_dim=attention_head_dim,
        bias_initializer=initializer,
        name='Relative-Bias',
    )(input_ids_trans)

    content_output = token_embed_dropout
    if FLAGS.short_cut_fake:
        attn_mask = tf.constant(1.0, shape=[512, 512, 1, 1], dtype=np.float32)
        segment_mat = tf.constant(1.0,
                                  shape=[512, 512, 1, 2],
                                  dtype=np.float32)
        pos_emb = tf.constant(1.0, shape=[1024, 1, 1024], dtype=np.float32)

    # Shapes of the tensors packed into ``fused``, in packing order. The
    # sequence-dependent sizes (512 / 1024) match the short_cut_fake
    # constants above. The per-layer parameter tensors are sized from the
    # hyper-parameters rather than a hard-coded 24-layer / 16-head / 64-dim
    # model, so the split also works for other configurations.
    # NOTE(review): the attn_mask / segment_mat shapes appear to assume a
    # batch of 1 — confirm.
    fused_shapes = [
        [512, 512, 1, 1],                                # attn_mask
        [512, 512, 1, 2],                                # segment_mat
        [num_layer, 2, num_head, attention_head_dim],    # segment_embed
        [1024, 1, 1024],                                 # pos_emb
        [num_layer, num_head, attention_head_dim],       # r_w_bias
        [num_layer, num_head, attention_head_dim],       # r_r_bias
        [num_layer, num_head, attention_head_dim],       # r_s_bias
    ]
    fused_sizes = [int(np.prod(shape)) for shape in fused_shapes]
    if FLAGS.short_cut_fuse:
        # Flatten and pack all shared per-layer inputs into one buffer.
        fused = tf.concat([
            tf.reshape(t, [-1])
            for t in (attn_mask, segment_mat, segment_embed, pos_emb,
                      r_w_bias, r_r_bias, r_s_bias)
        ], 0)

    for i in range(num_layer):
        attention = RelativeMultiHeadAttention(
            num_head=num_head,
            attention_head_dim=attention_head_dim,
            embedding_dim=embedding_dim,
            dropout=dropout,
            dropatt=attention_dropout,
            is_training=is_training,
            initializer=initializer,
            name='Attention-{}'.format(i + 1),
        )

        attention_add = tf.keras.layers.Add(
            name='Attention-Residual-{}'.format(i + 1))
        attention_layer_norm = LayerNormalization(
            name='Attention-Normal-{}'.format(i + 1))

        feed_forward = FeedForward(feed_forward_dim=feed_forward_dim,
                                   embedding_dim=embedding_dim,
                                   dropout_rate=dropout,
                                   kernel_initializer=initializer,
                                   activation=gelu,
                                   name='FeedForward-{}'.format(i + 1))
        feed_forward_add = tf.keras.layers.Add(
            name='FeedForward-Residual-{}'.format(i + 1))
        feed_forward_layer_norm = LayerNormalization(
            name='FeedForward-Normal-{}'.format(i + 1))

        if FLAGS.short_cut_fuse:
            # Re-split the fused buffer for this layer. This runs before the
            # per-layer slices below so that every layer, including layer 0,
            # reads its inputs from ``fused``.
            (attn_mask, segment_mat, segment_embed, pos_emb, r_w_bias,
             r_r_bias, r_s_bias) = [
                 tf.reshape(piece, shape) for piece, shape in zip(
                     tf.split(fused, fused_sizes, 0), fused_shapes)
             ]
            print(attn_mask, segment_mat, segment_embed, pos_emb, r_w_bias,
                  r_r_bias, r_s_bias)

        # Pass the layer index through ``arguments`` instead of closing over
        # the loop variable, so each Lambda keeps its own index even if its
        # function is invoked again later.
        segment_embed_i = keras.layers.Lambda(
            lambda x, idx: x[idx], arguments={'idx': i})(segment_embed)
        r_w_bias_i = keras.layers.Lambda(
            lambda x, idx: x[idx], arguments={'idx': i})(r_w_bias)
        r_r_bias_i = keras.layers.Lambda(
            lambda x, idx: x[idx], arguments={'idx': i})(r_r_bias)
        r_s_bias_i = keras.layers.Lambda(
            lambda x, idx: x[idx], arguments={'idx': i})(r_s_bias)

        def _build_output(query):
            """Apply one attention + feed-forward block with residual/LN."""
            attention_input = query
            _output = attention([
                query, pos_emb, segment_embed_i, segment_mat, r_w_bias_i,
                r_r_bias_i, r_s_bias_i, attn_mask
            ])
            _output = attention_add([attention_input, _output])
            _output = attention_layer_norm(_output)
            feed_forward_input = keras.layers.Lambda(lambda x: K.identity(x))(
                _output)
            _output = feed_forward(_output, training=is_training)
            _output = feed_forward_add([feed_forward_input, _output])
            _output = feed_forward_layer_norm(_output)
            return _output

        content_output = _build_output(content_output)

    output = keras.layers.Dropout(rate=dropout)(content_output,
                                                training=is_training)

    xlnet_loss = XLnetLoss(d_model=embedding_dim,
                           seq_len=target_len,
                           kernel_initializer=initializer,
                           name="XLNET_LOSS")([
                               cls_index, start_positions, end_positions,
                               is_impossible, p_mask, output
                           ])

    return xlnet_loss
예제 #4
0
    def __init__(self,
                 num_token,
                 num_layer,
                 num_head,
                 embedding_dim,
                 attention_head_dim,
                 feed_forward_dim,
                 target_len,
                 is_training,
                 memory_len=None,
                 dropout=0.0,
                 attention_dropout=0.0,
                 attention_type=None,
                 shared_biases=True):
        """Construct (but do not wire) every layer of the XLNet SQuAD model.

        Stores ``num_layer``, ``dropout`` and ``attention_dropout`` on the
        instance. It creates the token embedding, a transpose Lambda for
        segment ids, the relative segment embedding, the relative biases, and
        ``num_layer`` sets of attention / residual-add / layer-norm /
        feed-forward layers (one list attribute per layer kind, index ``k``
        holding layer ``k + 1``), plus the ``XLNET_LOSS`` layer.

        ``memory_len``, ``attention_type`` and ``shared_biases`` are accepted
        but not used here. Construction order is kept stable on purpose: any
        layer created without an explicit name is auto-named by creation
        order.
        """
        self.num_layer = num_layer
        self.dropout = dropout
        self.attention_dropout = attention_dropout

        self.token_embed = keras.layers.Embedding(
            input_dim=num_token, output_dim=embedding_dim, name='Embed-Token')

        # Normal initializer with stddev overridden to 0.02.
        weight_init = keras.initializers.get('normal')
        weight_init.stddev = 0.02

        self.segment_ids_trans = keras.layers.Lambda(lambda ids: K.transpose(ids))
        self.segment_mat_embed = RelativeSegmentEmbedding(
            num_layer=num_layer,
            num_head=num_head,
            attention_dim=attention_head_dim,
            initializer=weight_init,
            name='Embed-Segment')

        self.relative_bias = RelativeBias(
            num_layer=num_layer,
            num_head=num_head,
            attention_head_dim=attention_head_dim,
            bias_initializer=weight_init,
            name='Relative-Bias')

        self.attention = []
        self.attention_add = []
        self.attention_layer_norm = []
        self.feed_forward = []
        self.feed_forward_add = []
        self.feed_forward_layer_norm = []
        for layer_no in range(1, num_layer + 1):
            self.attention.append(RelativeMultiHeadAttention(
                num_head=num_head,
                attention_head_dim=attention_head_dim,
                embedding_dim=embedding_dim,
                dropout=dropout,
                dropatt=attention_dropout,
                is_training=is_training,
                initializer=weight_init,
                name='Attention-{}'.format(layer_no)))
            self.attention_add.append(tf.keras.layers.Add(
                name='Attention-Residual-{}'.format(layer_no)))
            self.attention_layer_norm.append(LayerNormalization(
                name='Attention-Normal-{}'.format(layer_no)))
            self.feed_forward.append(FeedForward(
                feed_forward_dim=feed_forward_dim,
                embedding_dim=embedding_dim,
                dropout_rate=dropout,
                kernel_initializer=weight_init,
                activation=gelu,
                name='FeedForward-{}'.format(layer_no)))
            self.feed_forward_add.append(tf.keras.layers.Add(
                name='FeedForward-Residual-{}'.format(layer_no)))
            self.feed_forward_layer_norm.append(LayerNormalization(
                name='FeedForward-Normal-{}'.format(layer_no)))

        self.xlnet_loss = XLnetLoss(d_model=embedding_dim,
                                    seq_len=target_len,
                                    kernel_initializer=weight_init,
                                    name="XLNET_LOSS")
예제 #5
0
    def build(self, inputs, dep_outputs=None, is_training=True):
        """Build the XLNet SQuAD graph split across pipeline devices.

        Devices come from
        ``cluster_utils.get_pipeline_devices(FLAGS.pipeline_device_num)``.
        Embeddings go on device 0, transformer layer ``i`` on
        ``calc_device(i + 2)``, and the final dropout + loss on
        ``calc_device(self.num_layer + 2)``.

        Args:
            inputs: sequence of eight tensors, unpacked as ``(input_ids,
                input_mask, segment_ids, cls_index, p_mask, start_positions,
                end_positions, is_impossible)``.
            dep_outputs: optional per-device control dependencies. Entry
                ``k`` (a tensor/op, a list of them, or None) is used as the
                control dependency when the graph first enters device ``k``.
                It applies on device 0 at the start, and otherwise whenever
                the device index changes.
            is_training: passed as ``training`` to dropout / feed-forward.

        Returns:
            ``(xlnet_loss, stage_outputs)``. ``stage_outputs`` holds the last
            output produced before each device change (``input_ids`` if the
            very first layer already changes device), followed by the loss.
        """
        devices = cluster_utils.get_pipeline_devices(FLAGS.pipeline_device_num)
        ndev = len(devices)

        def calc_device(stage):
            """Return the device index for pipeline ``stage`` (layer i -> i + 2)."""
            # Hand-tuned two-device split point:
            #   XLNet-24: stages < 13 (up to Forward-11) on device 0
            #   XLNet-36: stages < 19 (up to Forward-17) on device 0
            split_stage = 13 if self.num_layer == 24 else 19
            device = 0 if stage < split_stage else 1
            # Never index past the configured devices (e.g. a 1-device run).
            return min(device, ndev - 1)

        def deps_for(idx):
            """Control dependencies for device ``idx`` as a list, or None."""
            if dep_outputs is None or dep_outputs[idx] is None:
                return None
            dep_entry = dep_outputs[idx]
            return dep_entry if isinstance(dep_entry, list) else [dep_entry]

        def fake_shortcut_tensors():
            """Constant stand-ins for (attn_mask, segment_mat, pos_emb)."""
            return (tf.constant(1.0, shape=[512, 512, 1, 1], dtype=np.float32),
                    tf.constant(1.0, shape=[512, 512, 1, 2], dtype=np.float32),
                    tf.constant(1.0, shape=[1024, 1, 1024], dtype=np.float32))

        # Shapes of the tensors packed into ``fused``, in packing order. The
        # layer count must match the tensors actually concatenated, which are
        # built for ``self.num_layer`` layers.
        # NOTE(review): head count / head dim (16, 64) and the 512 / 1024
        # sequence sizes are still hard-coded — confirm for other models.
        num_layers = self.num_layer
        fused_shapes = [
            [512, 512, 1, 1],            # attn_mask
            [512, 512, 1, 2],            # segment_mat
            [num_layers, 2, 16, 64],     # segment_embed
            [1024, 1, 1024],             # pos_emb
            [num_layers, 16, 64],        # r_w_bias
            [num_layers, 16, 64],        # r_r_bias
            [num_layers, 16, 64],        # r_s_bias
        ]
        fused_sizes = [int(np.prod(shape)) for shape in fused_shapes]

        def fuse(tensors):
            """Flatten ``tensors`` and concatenate them into one 1-D buffer."""
            return tf.concat([tf.reshape(t, [-1]) for t in tensors], 0)

        def unfuse(buffer):
            """Split ``buffer`` back into tensors with ``fused_shapes``."""
            pieces = tf.split(buffer, fused_sizes, 0)
            return [tf.reshape(piece, shape)
                    for piece, shape in zip(pieces, fused_shapes)]

        # ---- Stage 0: embeddings on device 0 ----
        with tf.control_dependencies(deps_for(0)), tf.device(devices[0]):
            input_ids, input_mask, segment_ids, cls_index, \
            p_mask, start_positions, end_positions, is_impossible = inputs

            # 1MB, [512,512,1,1]
            attn_mask = get_attn_mask(input_mask)

            input_ids_trans = keras.layers.Lambda(lambda x: K.transpose(x))(
                input_ids)
            token_embed = self.token_embed(input_ids_trans)
            segment_ids_trans = self.segment_ids_trans(segment_ids)
            # 2MB [512,512,1,2] 192KB [24, 2, 16, 64]
            segment_mat, segment_embed = self.segment_mat_embed(
                segment_ids_trans)
            # 3*96KB [24, 16, 64]
            r_w_bias, r_r_bias, r_s_bias = self.relative_bias(input_ids_trans)
            token_embed_dropout = keras.layers.Dropout(
                rate=self.dropout,
                name='Embed-Token-Dropout')(token_embed, training=is_training)
            content_output = token_embed_dropout
            pos_emb = get_pos_emb([input_ids_trans, token_embed])
            # 4MB [1024, 1, 1024]
            pos_emb = keras.layers.Dropout(rate=self.dropout)(
                pos_emb, training=is_training)
            if FLAGS.short_cut_fake:
                attn_mask, segment_mat, pos_emb = fake_shortcut_tensors()
            if FLAGS.short_cut_fuse:
                fused = fuse([attn_mask, segment_mat, segment_embed, pos_emb,
                              r_w_bias, r_r_bias, r_s_bias])

        def _build_output(query, i, pos_emb, segment_mat, segment_embed, \
                          attn_mask, r_w_bias, r_r_bias, r_s_bias):
            """Apply transformer layer ``i`` (attention + FFN, residual + LN)."""
            segment_embed_i = keras.layers.Lambda(lambda x: x[i])(
                segment_embed)
            r_w_bias_i = keras.layers.Lambda(lambda x: x[i])(r_w_bias)
            r_r_bias_i = keras.layers.Lambda(lambda x: x[i])(r_r_bias)
            r_s_bias_i = keras.layers.Lambda(lambda x: x[i])(r_s_bias)
            attention_input = query
            _output = self.attention[i]([
                query, pos_emb, segment_embed_i, segment_mat, r_w_bias_i,
                r_r_bias_i, r_s_bias_i, attn_mask
            ])
            _output = self.attention_add[i]([attention_input, _output])
            _output = self.attention_layer_norm[i](_output)
            feed_forward_input = keras.layers.Lambda(lambda x: K.identity(x))(
                _output)
            _output = self.feed_forward[i](_output, training=is_training)
            _output = self.feed_forward_add[i]([feed_forward_input, _output])
            _output = self.feed_forward_layer_norm[i](_output)
            return _output

        # ---- Transformer layers ----
        stage_outputs = []
        # Output of the previous stage; for the first stage it is input_ids.
        prev_output = input_ids
        prev_device_idx = 0
        for i in range(self.num_layer):
            device_idx = calc_device(i + 2)
            dep = None
            boundary = device_idx != prev_device_idx
            if boundary:
                # This layer starts a new pipeline stage.
                dep = deps_for(device_idx)
                stage_outputs.append(prev_output)
                prev_device_idx = device_idx
            with tf.control_dependencies(dep), tf.device(devices[device_idx]):
                if boundary:
                    # Re-materialise the shared inputs on the new device.
                    if FLAGS.short_cut_fake:
                        attn_mask, segment_mat, pos_emb = fake_shortcut_tensors()
                    if FLAGS.short_cut_fuse:
                        (attn_mask, segment_mat, segment_embed, pos_emb,
                         r_w_bias, r_r_bias, r_s_bias) = unfuse(fused)
                        print(attn_mask, segment_mat, segment_embed, pos_emb,
                              r_w_bias, r_r_bias, r_s_bias)
                content_output = _build_output(content_output, i, pos_emb, segment_mat, \
                                segment_embed, attn_mask, r_w_bias, r_r_bias, r_s_bias)
            prev_output = content_output

        # ---- Final stage: dropout + loss ----
        device_idx = calc_device(self.num_layer + 2)
        dep = None
        if device_idx != prev_device_idx:
            # The loss starts a new pipeline stage.
            dep = deps_for(device_idx)
            stage_outputs.append(prev_output)
        with tf.control_dependencies(dep), tf.device(devices[device_idx]):
            output = keras.layers.Dropout(rate=self.dropout)(
                content_output, training=is_training)

            xlnet_loss = self.xlnet_loss([
                cls_index, start_positions, end_positions, is_impossible,
                p_mask, output
            ])
            stage_outputs.append(xlnet_loss)

        return xlnet_loss, stage_outputs
예제 #6
0
    def call(self, inputs, **kwargs):
        """Compute the XLNet SQuAD loss and register it with ``add_loss``.

        ``inputs`` is ``[cls_index, start_positions, end_positions,
        is_impossible, p_mask, output]``. ``output`` is consumed as
        ``[seq, batch, hidden]`` by the ``"lbh,bl->bh"`` einsums (e.g.
        ``512, ?, 1024``). Logits are masked with ``p_mask``: positions where
        it is 1 get a -1e30 offset.

        Loss = 0.5 * (start CE + end CE) + 0.5 * mean sigmoid CE of the
        answerability logit against ``is_impossible``.

        Returns:
            The scalar total loss (also passed to ``self.add_loss``).
        """
        cls_index, start_positions, end_positions, is_impossible, p_mask, output = inputs
        seq_len = self.max_seq_length

        def one_hot(positions):
            # One-hot over the sequence axis (depth = max_seq_length).
            return keras.layers.Lambda(
                lambda x: tf.one_hot(x[0], x[1], dtype=tf.float32))(
                    [positions, seq_len])

        def pool_over_seq(weights):
            # Weighted sum of ``output`` over the sequence axis -> [batch, hidden].
            return keras.layers.Lambda(
                lambda x: tf.einsum("lbh,bl->bh", x[0], x[1]))([output, weights])

        def log_softmax(logits):
            return keras.layers.Lambda(lambda x: tf.nn.log_softmax(x, -1))(logits)

        def mask_logits(logits):
            return logits * (1 - p_mask) - 1e30 * p_mask

        # Normalise the label tensors to rank 1 (they may arrive as [B] or [B, 1]).
        if len(start_positions.shape) == 1:
            start_positions, cls_index, end_positions, is_impossible = (
                K.expand_dims(t, axis=-1)
                for t in (start_positions, cls_index, end_positions, is_impossible))
        cls_index, start_positions, end_positions, is_impossible = (
            K.squeeze(t, -1)
            for t in (cls_index, start_positions, end_positions, is_impossible))

        # Start-position distribution.
        start_logits = K.transpose(K.squeeze(self.dense(output), -1))
        start_logits_masked = mask_logits(start_logits)
        start_log_probs = log_softmax(start_logits_masked)

        # End-position distribution, conditioned on the gold start features.
        start_index = one_hot(K.cast(start_positions, dtype=tf.int32))
        start_features = K.tile(K.expand_dims(pool_over_seq(start_index), 0),
                                [seq_len, 1, 1])
        end_hidden = self.dense_0(K.concatenate([output, start_features], axis=-1))
        end_hidden = self.layer_norm(end_hidden)
        end_logits = K.transpose(K.squeeze(self.dense_1(end_hidden), -1))
        end_logits_masked = mask_logits(end_logits)
        end_log_probs = log_softmax(end_logits_masked)

        # Span cross-entropy.
        start_loss = K.mean(-K.sum(start_log_probs * start_index, axis=-1))
        end_index = one_hot(K.cast(end_positions, dtype=tf.int32))
        end_loss = K.mean(-K.sum(end_log_probs * end_index, axis=-1))
        total_loss = (start_loss + end_loss) * 0.5

        # Answerability classifier over [expected start feature, CLS feature].
        cls_feature = pool_over_seq(one_hot(K.cast(cls_index, dtype=tf.int32)))
        start_p = keras.layers.Lambda(lambda x: tf.nn.softmax(x, axis=-1))(start_logits_masked)
        start_feature = pool_over_seq(start_p)
        ans_feature = self.dense_0_1(K.concatenate([start_feature, cls_feature], -1))
        # NOTE(review): dropout is forced on (training=True) regardless of
        # the learning phase — confirm this layer is only used for training.
        ans_feature = keras.layers.Dropout(rate=0.1)(ans_feature, training=True)
        cls_logits = K.squeeze(self.dense_1_1(ans_feature), -1)

        flat_is_impossible = K.reshape(is_impossible, [-1])
        regression_loss = keras.layers.Lambda(
            lambda x: tf.nn.sigmoid_cross_entropy_with_logits(labels=x[0],
                                                              logits=x[1]))(
                [flat_is_impossible, cls_logits])
        total_loss += K.mean(regression_loss) * 0.5
        self.add_loss(total_loss, inputs=True)
        return total_loss