Example #1
    def _attn(self, inputs, training=False):
        q, k, v, attention_mask, head_mask = inputs
        # q, k, v have shape [batch, heads, sequence, features]
        w = tf.matmul(q, k, transpose_b=True)
        if self.scale:
            dk = tf.cast(shape_list(k)[-1],
                         tf.float32)  # scale attention_scores
            w = w / tf.math.sqrt(dk)

        # w has shape [batch, heads, dst_sequence, src_sequence], where information flows from src to dst.
        _, _, nd, ns = shape_list(w)
        b = self.causal_attention_mask(nd, ns, dtype=w.dtype)
        b = tf.reshape(b, [1, 1, nd, ns])
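        # Keep scores at allowed positions and push masked positions to a large negative value before the softmax.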
        w = w * b - 1e4 * (1 - b)

        if attention_mask is not None:
            # Apply the attention mask
            w = w + attention_mask

        w = tf.nn.softmax(w, axis=-1)
        w = self.attn_dropout(w, training=training)

        # Mask heads if we want to
        if head_mask is not None:
            w = w * head_mask

        outputs = [tf.matmul(w, v)]
        if self.output_attentions:
            outputs.append(w)
        return outputs
    def call(self, seq, training=False):
        seq_fts = self.feature_conv(seq)
        f_1 = self.f1_conv(seq_fts)
        f_2 = self.f2_conv(seq_fts)
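        # Broadcast-add the two per-node score vectors into a [batch, seq, seq] matrix of pairwise attention logits (GAT-style).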
        logits = f_1 + tf.transpose(f_2, [0, 2, 1])
        coefs = tf.nn.softmax(tf.nn.leaky_relu(logits))
        vals = tf.matmul(coefs, seq_fts)
        ret = tf.nn.bias_add(vals, self.bias)

        # residual connection
        if self.residual:
            if shape_list(seq)[-1] != shape_list(ret)[-1]:
                ret = ret + self.res_conv(seq)
            else:
                ret = ret + seq
        return self.activation(ret)
Example #3
    def call(self, inputs, training=False):
        input_ids, position_ids, token_type_ids, inputs_embeds = inputs
        triletter_max_seq_len = shape_list(
            input_ids)[1] // self.triletter_max_letters_in_word
        position_embeddings = self.position_embeddings(position_ids)

        embeddings = self.triletter_embeddings(
            input_ids)  # [N, 12*[20], hidden_size]

        embeddings = tf.reshape(embeddings, [
            -1, triletter_max_seq_len, self.triletter_max_letters_in_word,
            shape_list(embeddings)[-1]
        ])
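        # Sum the tri-letter embeddings within each word so each word position gets a single embedding.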
        embeddings = tf.reshape(
            tf.reduce_sum(embeddings, axis=2),
            [-1, triletter_max_seq_len,
             shape_list(embeddings)[-1]])

        embeddings = embeddings + position_embeddings

        return embeddings
Example #4
    def call(self, inputs, training=False):
        input_ids, position_ids, token_type_ids, inputs_embeds = inputs
        input_shape = shape_list(input_ids)

        if inputs_embeds is None:
            inputs_embeds = tf.gather(self.word_embeddings, input_ids)
        position_embeddings = self.position_embeddings(position_ids)

        embeddings = inputs_embeds + position_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings, training=training)
        return embeddings
    def call(self, inputs, training=False):
        if self.num_layers == 0:
            return inputs["bert_0"]

        dim0 = shape_list(inputs["bert_1"])[-1]
        dims = [dim0] + self.config.hidden_dims[len(self.config.hidden_dims) - self.num_layers:]
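        # dims[layer] is the feature width of the representations fed into layer `layer`'s aggregator.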

        hidden = inputs

        for layer in range(self.num_layers):
            aggregator = self.aggs[layer]
            next_hidden = {}
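            # Update every remaining hop: aggregate a hop's own features with its fanout-reshaped neighbors one hop further out.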
            for hop in range(self.num_layers - layer):
                neigh_shape = [-1, self.fanouts[hop], dims[layer]]
                h = aggregator((hidden["bert_" + str(hop)], tf.reshape(hidden["bert_" + str(hop+1)], neigh_shape)))
                next_hidden["bert_" + str(hop)] = h
            hidden = next_hidden

        return hidden["bert_0"]
def hf_mlm_compute_loss(labels, logits):
    """
    Adapted to our dataset with labels with non-masked-tokens marker=0.
    """
    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=True, reduction=tf.keras.losses.Reduction.NONE)
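    # Keep only the positions whose label is non-zero, i.e. the tokens that were actually masked.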
    masked_lm_active_loss = tf.not_equal(
        tf.reshape(tensor=labels["mlm_long_labels"], shape=(-1, )), 0)
    masked_lm_reduced_logits = tf.boolean_mask(
        tensor=tf.reshape(tensor=logits[0],
                          shape=(-1, shape_list(logits[0])[2])),
        mask=masked_lm_active_loss,
    )
    masked_lm_labels = tf.boolean_mask(tensor=tf.reshape(
        tensor=labels["mlm_long_labels"], shape=(-1, )),
                                       mask=masked_lm_active_loss)
    masked_lm_loss = loss_fn(y_true=masked_lm_labels,
                             y_pred=masked_lm_reduced_logits)
    masked_lm_loss = tf.reduce_mean(input_tensor=masked_lm_loss)
    return masked_lm_loss
    def call(self, inputs, training=False):
        if self.num_layers == 0:
            return inputs["bert_0"]

        dim0 = shape_list(inputs["bert_1"])[-1]
        node_feats = tf.expand_dims(inputs["bert_0"], 1)
        neighbor_feats = tf.reshape(inputs["bert_1"], [-1, self.neighbor_num, dim0])
        seq = tf.concat([node_feats, neighbor_feats], 1)
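        # seq has shape [batch, 1 + neighbor_num, dim0]: the node's own features followed by its neighbors'.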

        for layer, head_num in enumerate(self.config.head_nums):
            hidden = []
            for i in range(head_num):
                hidden_val = self.attention_heads[layer][i](seq)
                hidden.append(hidden_val)
            seq = tf.concat(hidden, -1)

        # Average the outputs of the last layer's attention heads.
        out = tf.add_n(hidden) / self.config.head_nums[-1]
        out = tf.slice(out, [0, 0, 0], [-1, 1, self.config.hidden_dims[-1]])
        return tf.reshape(out, [-1, self.config.hidden_dims[-1]])
    def call(self, inputs, training=False):
        self_vecs, neigh_vecs = inputs

        # A neighbor timestep is considered valid if any of its features is non-zero.
        mask = tf.cast(tf.sign(tf.reduce_max(tf.abs(neigh_vecs), axis=2)), dtype=tf.bool)
        batch_size = shape_list(mask)[0]
        # Always keep the first timestep so every sequence has at least one valid position.
        mask = tf.concat([tf.ones([batch_size, 1], dtype=tf.bool), mask[:, 1:]], axis=1)

        rnn_outputs = self.lstm(inputs=neigh_vecs, mask=mask)

        from_neighs = self.neigh_weights(rnn_outputs)
        from_self = self.self_weights(self_vecs)

        if not self.concat:
            output = tf.add_n([from_self, from_neighs])
        else:
            output = tf.concat([from_self, from_neighs], axis=1)

        if self.add_bias:
            output += self.bias

        if self.identity_act:
            return output
        return self.act(output)
Example #9
    def call(self,
             inputs,
             past=None,
             attention_mask=None,
             token_type_ids=None,
             position_ids=None,
             head_mask=None,
             inputs_embeds=None,
             mc_token_ids=None,
             training=False):
        if isinstance(inputs, (tuple, list)):
            input_ids = inputs[0]
            past = inputs[1] if len(inputs) > 1 else past
            attention_mask = inputs[2] if len(inputs) > 2 else attention_mask
            token_type_ids = inputs[3] if len(inputs) > 3 else token_type_ids
            position_ids = inputs[4] if len(inputs) > 4 else position_ids
            head_mask = inputs[5] if len(inputs) > 5 else head_mask
            inputs_embeds = inputs[6] if len(inputs) > 6 else inputs_embeds
            mc_token_ids = inputs[7] if len(inputs) > 7 else mc_token_ids
            assert len(inputs) <= 8, "Too many inputs."
        elif isinstance(inputs, dict):
            input_ids = inputs.get('input_ids')
            past = inputs.get('past', past)
            attention_mask = inputs.get('attention_mask', attention_mask)
            token_type_ids = inputs.get('token_type_ids', token_type_ids)
            position_ids = inputs.get('position_ids', position_ids)
            head_mask = inputs.get('head_mask', head_mask)
            inputs_embeds = inputs.get('inputs_embeds', inputs_embeds)
            mc_token_ids = inputs.get('mc_token_ids', mc_token_ids)
            assert len(inputs) <= 8, "Too many inputs."
        else:
            input_ids = inputs

        if input_ids is not None:
            input_shapes = shape_list(input_ids)
        else:
            input_shapes = shape_list(inputs_embeds)[:-1]

        seq_length = input_shapes[-1]

        flat_input_ids = tf.reshape(
            input_ids, (-1, seq_length)) if input_ids is not None else None
        flat_attention_mask = tf.reshape(
            attention_mask,
            (-1, seq_length)) if attention_mask is not None else None
        flat_token_type_ids = tf.reshape(
            token_type_ids,
            (-1, seq_length)) if token_type_ids is not None else None
        flat_position_ids = tf.reshape(
            position_ids,
            (-1, seq_length)) if position_ids is not None else None

        flat_inputs = [
            flat_input_ids, past, flat_attention_mask, flat_token_type_ids,
            flat_position_ids, head_mask, inputs_embeds
        ]

        transformer_outputs = self.transformer(flat_inputs, training=training)
        hidden_states = transformer_outputs[0]
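        # hidden_states has the flattened shape, e.g. [batch * num_choices, seq_length, hidden_size];
        # the reshape below restores the original leading dimensions.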

        hidden_states = tf.reshape(
            hidden_states, input_shapes + shape_list(hidden_states)[-1:])

        lm_logits = self.transformer.wte(hidden_states, mode="linear")
        mc_logits = self.multiple_choice_head([hidden_states, mc_token_ids],
                                              training=training)

        mc_logits = tf.squeeze(mc_logits, axis=-1)

        outputs = (lm_logits, mc_logits) + transformer_outputs[1:]

        return outputs  # lm logits, mc logits, presents, (all hidden_states), (attentions)
Example #10
    def call(self,
             inputs,
             past=None,
             attention_mask=None,
             token_type_ids=None,
             position_ids=None,
             head_mask=None,
             inputs_embeds=None,
             training=False):
        if isinstance(inputs, (tuple, list)):
            input_ids = inputs[0]
            past = inputs[1] if len(inputs) > 1 else past
            attention_mask = inputs[2] if len(inputs) > 2 else attention_mask
            token_type_ids = inputs[3] if len(inputs) > 3 else token_type_ids
            position_ids = inputs[4] if len(inputs) > 4 else position_ids
            head_mask = inputs[5] if len(inputs) > 5 else head_mask
            inputs_embeds = inputs[6] if len(inputs) > 6 else inputs_embeds
            assert len(inputs) <= 7, "Too many inputs."
        elif isinstance(inputs, dict):
            input_ids = inputs.get('input_ids')
            past = inputs.get('past', past)
            attention_mask = inputs.get('attention_mask', attention_mask)
            token_type_ids = inputs.get('token_type_ids', token_type_ids)
            position_ids = inputs.get('position_ids', position_ids)
            head_mask = inputs.get('head_mask', head_mask)
            inputs_embeds = inputs.get('inputs_embeds', inputs_embeds)
            assert len(inputs) <= 7, "Too many inputs."
        else:
            input_ids = inputs

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time"
            )
        elif input_ids is not None:
            input_shape = shape_list(input_ids)
            input_ids = tf.reshape(input_ids, [-1, input_shape[-1]])
        elif inputs_embeds is not None:
            input_shape = shape_list(inputs_embeds)[:-1]
        else:
            raise ValueError(
                "You have to specify either input_ids or inputs_embeds")

        if past is None:
            past_length = 0
            past = [None] * len(self.h)
        else:
            past_length = shape_list(past[0][0])[-2]
        if position_ids is None:
            position_ids = tf.range(past_length,
                                    input_shape[-1] + past_length,
                                    dtype=tf.int32)[tf.newaxis, :]

        if attention_mask is not None:
            # We create a 3D attention mask from a 2D tensor mask.
            # Sizes are [batch_size, 1, 1, to_seq_length]
            # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
            # this attention mask is more simple than the triangular masking of causal attention
            # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
            attention_mask = attention_mask[:, tf.newaxis, tf.newaxis, :]

            # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
            # masked positions, this operation will create a tensor which is 0.0 for
            # positions we want to attend and -10000.0 for masked positions.
            # Since we are adding it to the raw scores before the softmax, this is
            # effectively the same as removing these entirely.

            attention_mask = tf.cast(attention_mask, tf.float32)
            attention_mask = (1.0 - attention_mask) * -10000.0
        else:
            attention_mask = None

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        if head_mask is not None:
            raise NotImplementedError
        else:
            head_mask = [None] * self.num_hidden_layers
            # head_mask = tf.constant([0] * self.num_hidden_layers)

        position_ids = tf.reshape(position_ids,
                                  [-1, shape_list(position_ids)[-1]])

        if inputs_embeds is None:
            inputs_embeds = self.wte(input_ids, mode='embedding')
        position_embeds = self.wpe(position_ids)
        if token_type_ids is not None:
            token_type_ids = tf.reshape(
                token_type_ids, [-1, shape_list(token_type_ids)[-1]])
            token_type_embeds = self.wte(token_type_ids, mode='embedding')
        else:
            token_type_embeds = 0
        hidden_states = inputs_embeds + position_embeds + token_type_embeds
        hidden_states = self.drop(hidden_states, training=training)

        output_shape = input_shape + [shape_list(hidden_states)[-1]]

        presents = ()
        all_attentions = []
        all_hidden_states = ()
        for i, (block, layer_past) in enumerate(zip(self.h, past)):
            if self.output_hidden_states:
                all_hidden_states = all_hidden_states + (tf.reshape(
                    hidden_states, output_shape), )

            outputs = block(
                [hidden_states, layer_past, attention_mask, head_mask[i]],
                training=training)

            hidden_states, present = outputs[:2]
            presents = presents + (present, )

            if self.output_attentions:
                all_attentions.append(outputs[2])

        hidden_states = self.ln_f(hidden_states)

        hidden_states = tf.reshape(hidden_states, output_shape)
        # Add last hidden state
        if self.output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states, )

        outputs = (hidden_states, presents)
        if self.output_hidden_states:
            outputs = outputs + (all_hidden_states, )
        if self.output_attentions:
            # let the number of heads free (-1) so we can extract attention even after head pruning
            attention_output_shape = input_shape[:-1] + [-1] + shape_list(
                all_attentions[0])[-2:]
            all_attentions = tuple(
                tf.reshape(t, attention_output_shape) for t in all_attentions)
            outputs = outputs + (all_attentions, )
        return outputs  # last hidden state, presents, (all hidden_states), (attentions)
Example #11
    def split_heads(self, x):
        x_shape = shape_list(x)
        new_x_shape = x_shape[:-1] + [self.n_head, x_shape[-1] // self.n_head]
        x = tf.reshape(x, new_x_shape)
        return tf.transpose(x, (0, 2, 1, 3))  # (batch, head, seq_length, head_features)
Example #12
    def merge_heads(self, x):
        # (batch, head, seq_length, head_features) -> (batch, seq_length, n_head * head_features)
        x = tf.transpose(x, [0, 2, 1, 3])
        x_shape = shape_list(x)
        new_x_shape = x_shape[:-2] + [x_shape[-2] * x_shape[-1]]
        return tf.reshape(x, new_x_shape)
Example #13
    def call(
        self,
        inputs,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        training=False,
    ):
        if isinstance(inputs, (tuple, list)):
            input_ids = inputs[0]
            attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
            token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
            position_ids = inputs[3] if len(inputs) > 3 else position_ids
            head_mask = inputs[4] if len(inputs) > 4 else head_mask
            inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
            output_attentions = inputs[6] if len(
                inputs) > 6 else output_attentions
            output_hidden_states = inputs[7] if len(
                inputs) > 7 else output_hidden_states
            assert len(inputs) <= 8, "Too many inputs."
        elif isinstance(inputs, (dict, BatchEncoding)):
            input_ids = inputs.get("input_ids")
            attention_mask = inputs.get("attention_mask", attention_mask)
            token_type_ids = inputs.get("token_type_ids", token_type_ids)
            position_ids = inputs.get("position_ids", position_ids)
            head_mask = inputs.get("head_mask", head_mask)
            inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
            output_attentions = inputs.get("output_attentions",
                                           output_attentions)
            output_hidden_states = inputs.get("output_hidden_states",
                                              output_hidden_states)
            assert len(inputs) <= 8, "Too many inputs."
        else:
            input_ids = inputs

        output_attentions = output_attentions if output_attentions is not None else self.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time"
            )
        elif input_ids is not None:
            input_shape = shape_list(input_ids)
        elif inputs_embeds is not None:
            input_shape = shape_list(inputs_embeds)[:-1]
        else:
            raise ValueError(
                "You have to specify either input_ids or inputs_embeds")

        # For tri-letter inputs the effective sequence length (in words) is the raw input length
        # divided by triletter_max_letters_in_word; build default masks and ids accordingly.
        if attention_mask is None:
            if type(self.embeddings) == TriletterEmbeddings:
                attention_mask = tf.ones([
                    input_shape[0], input_shape[1] //
                    self.embeddings.triletter_max_letters_in_word
                ])
            else:
                attention_mask = tf.fill(input_shape, 1)
        if token_type_ids is None:
            if type(self.embeddings) == TriletterEmbeddings:
                token_type_ids = tf.zeros([
                    input_shape[0], input_shape[1] //
                    self.embeddings.triletter_max_letters_in_word
                ])
            else:
                token_type_ids = tf.fill(input_shape, 0)
        if position_ids is None:
            if type(self.embeddings) == TriletterEmbeddings:
                position_ids = (
                    tf.range(input_shape[1] //
                             self.embeddings.triletter_max_letters_in_word,
                             dtype=tf.int32) + 1)[tf.newaxis, :]
            else:
                position_ids = tf.range(int(input_shape[1]),
                                        dtype=tf.int32)[tf.newaxis, :]
            position_ids = tf.where(attention_mask == 0,
                                    tf.zeros_like(position_ids), position_ids)

        # Broadcast the 2-D attention mask to [batch, 1, 1, seq_length] and convert it to an
        # additive mask: 0.0 for positions to attend to, -10000.0 for masked positions.
        extended_attention_mask = attention_mask[:, tf.newaxis, tf.newaxis, :]
        extended_attention_mask = tf.cast(extended_attention_mask, tf.float32)
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        if head_mask is not None:
            raise NotImplementedError
        else:
            head_mask = [None] * self.num_hidden_layers

        embedding_output = self.embeddings(
            [input_ids, position_ids, token_type_ids, inputs_embeds],
            training=training)
        encoder_outputs = self.encoder(
            [
                embedding_output, extended_attention_mask, head_mask,
                output_attentions, output_hidden_states
            ],
            training=training,
        )
        sequence_output = encoder_outputs[0]
        pooled_output = self.pooler(sequence_output)

        outputs = (
            sequence_output,
            pooled_output,
        ) + encoder_outputs[
            1:]  # add hidden_states and attentions if they are here

        return outputs  # sequence_output, pooled_output, (hidden_states), (attentions)
    def call(self, inputs, training=False):
        if self.num_layers == 0:
            return inputs["bert_0"]

        hidden = inputs
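        # Each layer concatenates a hop's features with its flattened neighbor features and applies that layer's dense transformation.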
        for layer in range(self.num_layers):
            next_hidden = {}
            for hop in range(self.num_layers - layer):
                neighbor = tf.reshape(hidden["bert_" + str(hop+1)], [-1, self.fanouts[hop] * shape_list(hidden["bert_" + str(hop+1)])[-1]])
                seq = tf.concat([hidden["bert_" + str(hop)], neighbor], axis=-1)
                h = self.dense_layers[layer](seq, training=training)
                next_hidden["bert_" + str(hop)] = h
            hidden = next_hidden
        return hidden["bert_0"]