Example #1
    def _init_graph(self):
        self.seq_lengths = tf.reduce_sum(self.y_masks_ph, axis=1)

        self.bert = BertModel(config=self.bert_config,
                              is_training=self.is_train_ph,
                              input_ids=self.input_ids_ph,
                              input_mask=self.input_masks_ph,
                              token_type_ids=self.token_types_ph,
                              use_one_hot_embeddings=False)

        with tf.variable_scope('ner'):
            layer_weights = tf.get_variable('layer_weights_',
                                            shape=len(self.encoder_layer_ids),
                                            initializer=tf.ones_initializer(),
                                            trainable=True)
            layer_mask = tf.ones_like(layer_weights)
            layer_mask = tf.nn.dropout(layer_mask, self.encoder_keep_prob_ph)
            layer_weights *= layer_mask
            # to prevent zero division
            mask_sum = tf.maximum(tf.reduce_sum(layer_mask), 1.0)
            layer_weights = tf.unstack(layer_weights / mask_sum)
            # TODO: may be stack and reduce_sum is faster
            units = sum(w * l
                        for w, l in zip(layer_weights, self.encoder_layers()))
            units = tf.nn.dropout(units, keep_prob=self.keep_prob_ph)
        return units
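The layer_mask trick above drops whole encoder layers during training: tf.nn.dropout zeroes some entries of a ones vector (scaling the survivors by 1/keep_prob), and dividing by mask_sum renormalizes the weights of the layers that were kept, while tf.maximum guards against an all-dropped mask. A toy numpy illustration of the same arithmetic, with made-up numbers (not DeepPavlov code):

import numpy as np

layer_weights = np.ones(4)                                 # learned weights, initialized to ones
keep_prob = 0.5
# tf.nn.dropout keeps each entry with probability keep_prob and scales survivors by 1/keep_prob
layer_mask = np.array([1.0, 0.0, 1.0, 0.0]) / keep_prob    # suppose layers 0 and 2 survive
weights = layer_weights * layer_mask                       # [2., 0., 2., 0.]
mask_sum = max(layer_mask.sum(), 1.0)                      # 4.0; the max() prevents division by zero
print(weights / mask_sum)                                  # [0.5 0.  0.5 0. ] -> average of the kept layers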
Example #2
    def _init_graph(self):
        self._init_placeholders()

        self.bert = BertModel(config=self.bert_config,
                              is_training=self.is_train_ph,
                              input_ids=self.input_ids_ph,
                              input_mask=self.input_masks_ph,
                              token_type_ids=self.token_types_ph,
                              use_one_hot_embeddings=False)

        encoder_layers = [
            self.bert.all_encoder_layers[i] for i in self.encoder_layer_ids
        ]

        with tf.variable_scope('ner'):
            output_layer = sum(encoder_layers) / len(encoder_layers)
            output_layer = tf.nn.dropout(output_layer,
                                         keep_prob=self.keep_prob_ph)

            logits = tf.layers.dense(output_layer,
                                     units=self.n_tags,
                                     name="output_dense")

            self.y_predictions = tf.argmax(logits, -1)
            self.y_probas = tf.nn.softmax(logits, axis=2)

        with tf.variable_scope("loss"):
            y_mask = tf.cast(self.input_masks_ph, tf.float32)
            self.loss = tf.losses.sparse_softmax_cross_entropy(
                labels=self.y_ph, logits=logits, weights=y_mask)
Example #3
    def _init_graph(self):
        self._init_placeholders()

        self.bert = BertModel(
            config=self.bert_config,
            is_training=self.is_train_ph,
            input_ids=self.input_ids_ph,
            input_mask=self.input_masks_ph,
            token_type_ids=self.token_types_ph,
            use_one_hot_embeddings=False,
        )

        output_layer = self.bert.get_pooled_output()
        hidden_size = output_layer.shape[-1].value

        output_weights = tf.get_variable(
            "output_weights", [self.n_classes, hidden_size],
            initializer=tf.truncated_normal_initializer(stddev=0.02))

        output_bias = tf.get_variable("output_bias", [self.n_classes],
                                      initializer=tf.zeros_initializer())

        with tf.variable_scope("loss"):
            output_layer = tf.nn.dropout(output_layer,
                                         keep_prob=self.keep_prob_ph)
            logits = tf.matmul(output_layer, output_weights, transpose_b=True)
            logits = tf.nn.bias_add(logits, output_bias)

            if self.one_hot_labels:
                one_hot_labels = self.y_ph
            else:
                one_hot_labels = tf.one_hot(self.y_ph,
                                            depth=self.n_classes,
                                            dtype=tf.float32)

            self.y_predictions = tf.argmax(logits, axis=-1)
            if not self.multilabel:
                log_probs = tf.nn.log_softmax(logits, axis=-1)
                self.y_probas = tf.nn.softmax(logits, axis=-1)
                per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs,
                                                  axis=-1)
                self.loss = tf.reduce_mean(per_example_loss)
            else:
                self.y_probas = tf.nn.sigmoid(logits)
                self.loss = tf.reduce_mean(
                    tf.nn.sigmoid_cross_entropy_with_logits(
                        labels=one_hot_labels, logits=logits))
Example #4
    def _init_graph(self):
        self._init_placeholders()
        with tf.variable_scope("model"):
            self.bert = BertModel(
                config=self.bert_config,
                is_training=self.is_train_ph,
                input_ids=self.input_ids_ph,
                input_mask=self.input_masks_ph,
                token_type_ids=self.token_types_ph,
                use_one_hot_embeddings=False,
            )

        output_layer_a = self.bert.get_pooled_output()

        with tf.variable_scope("loss"):
            with tf.variable_scope("loss"):
                self.loss = tf.contrib.losses.metric_learning.npairs_loss(
                    self.y_ph, output_layer_a, output_layer_a)
                self.y_probas = output_layer_a
Example #5
    def _init_graph(self):
        self._init_placeholders()

        with tf.variable_scope("model"):
            model_a = BertModel(config=self.bert_config,
                                is_training=self.is_train_ph,
                                input_ids=self.input_ids_a_ph,
                                input_mask=self.input_masks_a_ph,
                                token_type_ids=self.token_types_a_ph,
                                use_one_hot_embeddings=False)

        with tf.variable_scope("model", reuse=True):
            model_b = BertModel(config=self.bert_config,
                                is_training=self.is_train_ph,
                                input_ids=self.input_ids_b_ph,
                                input_mask=self.input_masks_b_ph,
                                token_type_ids=self.token_types_b_ph,
                                use_one_hot_embeddings=False)

        output_layer_a = model_a.get_pooled_output()
        output_layer_b = model_b.get_pooled_output()

        with tf.variable_scope("loss"):
            output_layer_a = tf.nn.dropout(output_layer_a,
                                           keep_prob=self.keep_prob_ph)
            output_layer_b = tf.nn.dropout(output_layer_b,
                                           keep_prob=self.keep_prob_ph)
            self.loss = tf.contrib.losses.metric_learning.npairs_loss(
                self.y_ph, output_layer_a, output_layer_b)
            logits = tf.multiply(output_layer_a, output_layer_b)
            self.y_probas = tf.reduce_sum(logits, 1)
            self.pooled_out = output_layer_a
Example #6
    def _init_graph(self):
        self._init_placeholders()

        self.bert = BertModel(config=self.bert_config,
                              is_training=self.is_train_ph,
                              input_ids=self.input_ids_ph,
                              input_mask=self.input_masks_ph,
                              token_type_ids=self.token_types_ph,
                              use_one_hot_embeddings=False,
                              )
        # next sentence prediction head
        with tf.variable_scope("cls/seq_relationship"):
            output_weights = tf.get_variable(
                "output_weights",
                shape=[2, self.bert_config.hidden_size],
                initializer=create_initializer(self.bert_config.initializer_range))
            output_bias = tf.get_variable(
                "output_bias", shape=[2], initializer=tf.zeros_initializer())

        nsp_logits = tf.matmul(self.bert.get_pooled_output(), output_weights, transpose_b=True)
        nsp_logits = tf.nn.bias_add(nsp_logits, output_bias)
        self.nsp_probs = tf.nn.softmax(nsp_logits, axis=-1)
Example #7
import tensorflow as tf

from bert_dp.modeling import BertConfig, BertModel
from deeppavlov.models.preprocessors.bert_preprocessor import BertPreprocessor

bert_config = BertConfig.from_json_file(
    './bg_cs_pl_ru_cased_L-12_H-768_A-12/bert_config.json')

input_ids = tf.placeholder(shape=(None, None), dtype=tf.int32)
input_mask = tf.placeholder(shape=(None, None), dtype=tf.int32)
token_type_ids = tf.placeholder(shape=(None, None), dtype=tf.int32)

bert = BertModel(config=bert_config,
                 is_training=False,
                 input_ids=input_ids,
                 input_mask=input_mask,
                 token_type_ids=token_type_ids,
                 use_one_hot_embeddings=False)

preprocessor = BertPreprocessor(
    vocab_file='./bg_cs_pl_ru_cased_L-12_H-768_A-12/vocab.txt',
    do_lower_case=False,
    max_seq_length=512)

with tf.Session() as sess:

    # Load model
    tf.train.Saver().restore(
        sess, './bg_cs_pl_ru_cased_L-12_H-768_A-12/bert_model.ckpt')

    # Get predictions
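    # A hedged completion sketch (not part of the original snippet): BertPreprocessor is
    # assumed to return InputFeatures objects with input_ids, input_mask and input_type_ids
    # fields, as in the other examples on this page.
    features = preprocessor(['Bert is a contextual language model.'])[0]
    token_embeddings = sess.run(bert.get_sequence_output(),
                                feed_dict={input_ids: [features.input_ids],
                                           input_mask: [features.input_mask],
                                           token_type_ids: [features.input_type_ids]})
    print(token_embeddings.shape)  # (batch_size, max_seq_length, hidden_size)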
Example #8
class BertSQuADModel(LRScheduledTFModel):
    """Bert-based model for SQuAD-like problem setting:
    It predicts the start and end positions of the answer for a given question and context.

    The [CLS] token is used as no_answer: if the model selects the [CLS] token as the most
    probable answer, there is no answer in the given context.

    The start and end positions of the answer are predicted by a linear transformation
    of the Bert outputs. (A toy sketch of the span-selection rule follows this example.)

    Args:
        bert_config_file: path to Bert configuration file
        keep_prob: dropout keep_prob for non-Bert layers
        attention_probs_keep_prob: keep_prob for Bert self-attention layers
        hidden_keep_prob: keep_prob for Bert hidden layers
        optimizer: name of tf.train.* optimizer or None for `AdamWeightDecayOptimizer`
        weight_decay_rate: L2 weight decay for `AdamWeightDecayOptimizer`
        pretrained_bert: pretrained Bert checkpoint
        min_learning_rate: min value of learning rate if learning rate decay is used
    """

    def __init__(self, bert_config_file: str,
                 keep_prob: float,
                 attention_probs_keep_prob: Optional[float] = None,
                 hidden_keep_prob: Optional[float] = None,
                 optimizer: Optional[str] = None,
                 weight_decay_rate: Optional[float] = 0.01,
                 pretrained_bert: Optional[str] = None,
                 min_learning_rate: float = 1e-06, **kwargs) -> None:
        super().__init__(**kwargs)

        self.min_learning_rate = min_learning_rate
        self.keep_prob = keep_prob
        self.optimizer = optimizer
        self.weight_decay_rate = weight_decay_rate

        self.bert_config = BertConfig.from_json_file(str(expand_path(bert_config_file)))

        if attention_probs_keep_prob is not None:
            self.bert_config.attention_probs_dropout_prob = 1.0 - attention_probs_keep_prob
        if hidden_keep_prob is not None:
            self.bert_config.hidden_dropout_prob = 1.0 - hidden_keep_prob

        self.sess_config = tf.ConfigProto(allow_soft_placement=True)
        self.sess_config.gpu_options.allow_growth = True
        self.sess = tf.Session(config=self.sess_config)

        self._init_graph()

        self._init_optimizer()

        self.sess.run(tf.global_variables_initializer())

        if pretrained_bert is not None:
            pretrained_bert = str(expand_path(pretrained_bert))

            if tf.train.checkpoint_exists(pretrained_bert) \
                    and not tf.train.checkpoint_exists(str(self.load_path.resolve())):
                logger.info('[initializing model with Bert from {}]'.format(pretrained_bert))
                var_list = self._get_saveable_variables(
                    exclude_scopes=('Optimizer', 'learning_rate', 'momentum', 'squad'))
                saver = tf.train.Saver(var_list)
                saver.restore(self.sess, pretrained_bert)

        if self.load_path is not None:
            self.load()

    def _init_graph(self):
        self._init_placeholders()

        seq_len = tf.shape(self.input_ids_ph)[-1]
        self.y_st = tf.one_hot(self.y_st_ph, depth=seq_len)
        self.y_end = tf.one_hot(self.y_end_ph, depth=seq_len)

        self.bert = BertModel(config=self.bert_config,
                              is_training=self.is_train_ph,
                              input_ids=self.input_ids_ph,
                              input_mask=self.input_masks_ph,
                              token_type_ids=self.token_types_ph,
                              use_one_hot_embeddings=False,
                              )

        last_layer = self.bert.get_sequence_output()
        hidden_size = last_layer.get_shape().as_list()[-1]
        bs = tf.shape(last_layer)[0]

        with tf.variable_scope('squad'):
            output_weights = tf.get_variable('output_weights', [2, hidden_size],
                                             initializer=tf.truncated_normal_initializer(stddev=0.02))
            output_bias = tf.get_variable('output_bias', [2], initializer=tf.zeros_initializer())

            last_layer_rs = tf.reshape(last_layer, [-1, hidden_size])

            logits = tf.matmul(last_layer_rs, output_weights, transpose_b=True)
            logits = tf.nn.bias_add(logits, output_bias)
            logits = tf.reshape(logits, [bs, -1, 2])
            logits = tf.transpose(logits, [2, 0, 1])

            logits_st, logits_end = tf.unstack(logits, axis=0)

            logit_mask = self.token_types_ph
            # [CLS] token is used as no answer
            mask = tf.concat([tf.ones((bs, 1), dtype=tf.int32), tf.zeros((bs, seq_len-1), dtype=tf.int32)], axis=-1)
            logit_mask = logit_mask + mask

            logits_st = softmax_mask(logits_st, logit_mask)
            logits_end = softmax_mask(logits_end, logit_mask)
            start_probs = tf.nn.softmax(logits_st)
            end_probs = tf.nn.softmax(logits_end)

            outer = tf.matmul(tf.expand_dims(start_probs, axis=2), tf.expand_dims(end_probs, axis=1))
            outer_logits = tf.exp(tf.expand_dims(logits_st, axis=2) + tf.expand_dims(logits_end, axis=1))

            context_max_len = tf.reduce_max(tf.reduce_sum(self.token_types_ph, axis=1))

            max_ans_length = tf.cast(tf.minimum(20, context_max_len), tf.int64)
            outer = tf.matrix_band_part(outer, 0, max_ans_length)
            outer_logits = tf.matrix_band_part(outer_logits, 0, max_ans_length)

            self.yp_score = 1 - tf.nn.softmax(logits_st)[:, 0] * tf.nn.softmax(logits_end)[:, 0]

            self.start_probs = start_probs
            self.end_probs = end_probs
            self.start_pred = tf.argmax(tf.reduce_max(outer, axis=2), axis=1)
            self.end_pred = tf.argmax(tf.reduce_max(outer, axis=1), axis=1)
            self.yp_logits = tf.reduce_max(tf.reduce_max(outer_logits, axis=2), axis=1)

        with tf.variable_scope("loss"):
            loss_st = tf.nn.softmax_cross_entropy_with_logits(logits=logits_st, labels=self.y_st)
            loss_end = tf.nn.softmax_cross_entropy_with_logits(logits=logits_end, labels=self.y_end)
            self.loss = tf.reduce_mean(loss_st + loss_end)

    def _init_placeholders(self):
        self.input_ids_ph = tf.placeholder(shape=(None, None), dtype=tf.int32, name='ids_ph')
        self.input_masks_ph = tf.placeholder(shape=(None, None), dtype=tf.int32, name='masks_ph')
        self.token_types_ph = tf.placeholder(shape=(None, None), dtype=tf.int32, name='token_types_ph')

        self.y_st_ph = tf.placeholder(shape=(None,), dtype=tf.int32, name='y_st_ph')
        self.y_end_ph = tf.placeholder(shape=(None,), dtype=tf.int32, name='y_end_ph')

        self.learning_rate_ph = tf.placeholder_with_default(0.0, shape=[], name='learning_rate_ph')
        self.keep_prob_ph = tf.placeholder_with_default(1.0, shape=[], name='keep_prob_ph')
        self.is_train_ph = tf.placeholder_with_default(False, shape=[], name='is_train_ph')

    def _init_optimizer(self):
        with tf.variable_scope('Optimizer'):
            self.global_step = tf.get_variable('global_step', shape=[], dtype=tf.int32,
                                               initializer=tf.constant_initializer(0), trainable=False)
            # default optimizer for Bert is Adam with fixed L2 regularization
            if self.optimizer is None:

                self.train_op = self.get_train_op(self.loss, learning_rate=self.learning_rate_ph,
                                                  optimizer=AdamWeightDecayOptimizer,
                                                  weight_decay_rate=self.weight_decay_rate,
                                                  beta_1=0.9,
                                                  beta_2=0.999,
                                                  epsilon=1e-6,
                                                  exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"]
                                                  )
            else:
                self.train_op = self.get_train_op(self.loss, learning_rate=self.learning_rate_ph)

            if self.optimizer is None:
                new_global_step = self.global_step + 1
                self.train_op = tf.group(self.train_op, [self.global_step.assign(new_global_step)])

    def _build_feed_dict(self, input_ids, input_masks, token_types, y_st=None, y_end=None):
        feed_dict = {
            self.input_ids_ph: input_ids,
            self.input_masks_ph: input_masks,
            self.token_types_ph: token_types,
        }
        if y_st is not None and y_end is not None:
            feed_dict.update({
                self.y_st_ph: y_st,
                self.y_end_ph: y_end,
                self.learning_rate_ph: max(self.get_learning_rate(), self.min_learning_rate),
                self.keep_prob_ph: self.keep_prob,
                self.is_train_ph: True,
            })

        return feed_dict

    def train_on_batch(self, features: List[InputFeatures], y_st: List[List[int]], y_end: List[List[int]]) -> Dict:
        """Train model on given batch.
        This method calls train_op using features and labels from y_st and y_end

        Args:
            features: batch of InputFeatures instances
            y_st: batch of lists of ground truth answer start positions
            y_end: batch of lists of ground truth answer end positions

        Returns:
            dict with loss and learning_rate values

        """
        input_ids = [f.input_ids for f in features]
        input_masks = [f.input_mask for f in features]
        input_type_ids = [f.input_type_ids for f in features]

        y_st = [x[0] for x in y_st]
        y_end = [x[0] for x in y_end]

        feed_dict = self._build_feed_dict(input_ids, input_masks, input_type_ids, y_st, y_end)

        _, loss = self.sess.run([self.train_op, self.loss], feed_dict=feed_dict)
        return {'loss': loss, 'learning_rate': feed_dict[self.learning_rate_ph]}

    def __call__(self, features: List[InputFeatures]) -> Tuple[List[int], List[int], List[float], List[float]]:
        """get predictions using features as input

        Args:
            features: batch of InputFeatures instances

        Returns:
            predictions: start, end positions, logits for answer and no_answer score

        """
        input_ids = [f.input_ids for f in features]
        input_masks = [f.input_mask for f in features]
        input_type_ids = [f.input_type_ids for f in features]

        feed_dict = self._build_feed_dict(input_ids, input_masks, input_type_ids)
        st, end, logits, scores = self.sess.run([self.start_pred, self.end_pred, self.yp_logits, self.yp_score], feed_dict=feed_dict)
        return st, end, logits.tolist(), scores.tolist()
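The span selection above combines start and end probabilities through an outer product and restricts it to short spans with tf.matrix_band_part. A toy numpy sketch of the same selection rule, with made-up probabilities (not DeepPavlov code):

import numpy as np

start_probs = np.array([0.05, 0.1, 0.6, 0.2, 0.05])   # P(answer starts at token i)
end_probs = np.array([0.05, 0.05, 0.1, 0.7, 0.1])     # P(answer ends at token j)

outer = np.outer(start_probs, end_probs)              # P(start = i) * P(end = j)
max_ans_length = 2
# keep only spans with 0 <= j - i <= max_ans_length, as tf.matrix_band_part(outer, 0, max_ans_length) does
i_idx, j_idx = np.indices(outer.shape)
outer[(j_idx < i_idx) | (j_idx - i_idx > max_ans_length)] = 0.0

start_pred = int(np.argmax(outer.max(axis=1)))        # argmax of row maxima, as in the graph (per batch element)
end_pred = int(np.argmax(outer.max(axis=0)))          # argmax of column maxima
print(start_pred, end_pred)                           # -> 2 3 for these toy numbers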
Example #9
    def _init_graph(self):
        self._init_placeholders()

        seq_len = tf.shape(self.input_ids_ph)[-1]
        self.y_st = tf.one_hot(self.y_st_ph, depth=seq_len)
        self.y_end = tf.one_hot(self.y_end_ph, depth=seq_len)

        self.bert = BertModel(config=self.bert_config,
                              is_training=self.is_train_ph,
                              input_ids=self.input_ids_ph,
                              input_mask=self.input_masks_ph,
                              token_type_ids=self.token_types_ph,
                              use_one_hot_embeddings=False,
                              )

        last_layer = self.bert.get_sequence_output()
        hidden_size = last_layer.get_shape().as_list()[-1]
        bs = tf.shape(last_layer)[0]

        with tf.variable_scope('squad'):
            output_weights = tf.get_variable('output_weights', [2, hidden_size],
                                             initializer=tf.truncated_normal_initializer(stddev=0.02))
            output_bias = tf.get_variable('output_bias', [2], initializer=tf.zeros_initializer())

            last_layer_rs = tf.reshape(last_layer, [-1, hidden_size])

            logits = tf.matmul(last_layer_rs, output_weights, transpose_b=True)
            logits = tf.nn.bias_add(logits, output_bias)
            logits = tf.reshape(logits, [bs, -1, 2])
            logits = tf.transpose(logits, [2, 0, 1])

            logits_st, logits_end = tf.unstack(logits, axis=0)

            logit_mask = self.token_types_ph
            # [CLS] token is used as no answer
            mask = tf.concat([tf.ones((bs, 1), dtype=tf.int32), tf.zeros((bs, seq_len-1), dtype=tf.int32)], axis=-1)
            logit_mask = logit_mask + mask

            logits_st = softmax_mask(logits_st, logit_mask)
            logits_end = softmax_mask(logits_end, logit_mask)
            start_probs = tf.nn.softmax(logits_st)
            end_probs = tf.nn.softmax(logits_end)

            outer = tf.matmul(tf.expand_dims(start_probs, axis=2), tf.expand_dims(end_probs, axis=1))
            outer_logits = tf.exp(tf.expand_dims(logits_st, axis=2) + tf.expand_dims(logits_end, axis=1))

            context_max_len = tf.reduce_max(tf.reduce_sum(self.token_types_ph, axis=1))

            max_ans_length = tf.cast(tf.minimum(20, context_max_len), tf.int64)
            outer = tf.matrix_band_part(outer, 0, max_ans_length)
            outer_logits = tf.matrix_band_part(outer_logits, 0, max_ans_length)

            self.yp_score = 1 - tf.nn.softmax(logits_st)[:, 0] * tf.nn.softmax(logits_end)[:, 0]

            self.start_probs = start_probs
            self.end_probs = end_probs
            self.start_pred = tf.argmax(tf.reduce_max(outer, axis=2), axis=1)
            self.end_pred = tf.argmax(tf.reduce_max(outer, axis=1), axis=1)
            self.yp_logits = tf.reduce_max(tf.reduce_max(outer_logits, axis=2), axis=1)

        with tf.variable_scope("loss"):
            loss_st = tf.nn.softmax_cross_entropy_with_logits(logits=logits_st, labels=self.y_st)
            loss_end = tf.nn.softmax_cross_entropy_with_logits(logits=logits_end, labels=self.y_end)
            self.loss = tf.reduce_mean(loss_st + loss_end)
Example #10
class BertClassifierModel(LRScheduledTFModel):
    """Bert-based model for text classification.

    It uses the output for the [CLS] token and predicts labels with a linear transformation.

    Args:
        bert_config_file: path to Bert configuration file
        n_classes: number of classes
        keep_prob: dropout keep_prob for non-Bert layers
        one_hot_labels: set True if one-hot encoding for labels is used
        multilabel: set True if it is multi-label classification
        return_probas: set True to return class probabilities instead of the most probable label
        attention_probs_keep_prob: keep_prob for Bert self-attention layers
        hidden_keep_prob: keep_prob for Bert hidden layers
        optimizer: name of tf.train.* optimizer or None for `AdamWeightDecayOptimizer`
        num_warmup_steps:
        weight_decay_rate: L2 weight decay for `AdamWeightDecayOptimizer`
        pretrained_bert: pretrained Bert checkpoint
        min_learning_rate: min value of learning rate if learning rate decay is used
    """

    # TODO: add warmup
    # TODO: add head-only pre-training
    def __init__(self,
                 bert_config_file,
                 n_classes,
                 keep_prob,
                 one_hot_labels=False,
                 multilabel=False,
                 return_probas=False,
                 attention_probs_keep_prob=None,
                 hidden_keep_prob=None,
                 optimizer=None,
                 num_warmup_steps=None,
                 weight_decay_rate=0.01,
                 pretrained_bert=None,
                 min_learning_rate=1e-06,
                 **kwargs) -> None:
        super().__init__(**kwargs)

        self.return_probas = return_probas
        self.n_classes = n_classes
        self.min_learning_rate = min_learning_rate
        self.keep_prob = keep_prob
        self.one_hot_labels = one_hot_labels
        self.multilabel = multilabel
        self.optimizer = optimizer
        self.num_warmup_steps = num_warmup_steps
        self.weight_decay_rate = weight_decay_rate

        if self.multilabel and not self.one_hot_labels:
            raise RuntimeError(
                'Use one-hot encoded labels for multilabel classification!')

        if self.multilabel and not self.return_probas:
            raise RuntimeError(
                'Set return_probas to True for multilabel classification!')

        self.bert_config = BertConfig.from_json_file(
            str(expand_path(bert_config_file)))

        if attention_probs_keep_prob is not None:
            self.bert_config.attention_probs_dropout_prob = 1.0 - attention_probs_keep_prob
        if hidden_keep_prob is not None:
            self.bert_config.hidden_dropout_prob = 1.0 - hidden_keep_prob

        self.sess_config = tf.ConfigProto(allow_soft_placement=True)
        self.sess_config.gpu_options.allow_growth = True
        self.sess = tf.Session(config=self.sess_config)

        self._init_graph()

        self._init_optimizer()

        self.sess.run(tf.global_variables_initializer())

        if pretrained_bert is not None:
            pretrained_bert = str(expand_path(pretrained_bert))

            if tf.train.checkpoint_exists(pretrained_bert) \
                    and not (self.load_path and tf.train.checkpoint_exists(str(self.load_path.resolve()))):
                logger.info('[initializing model with Bert from {}]'.format(
                    pretrained_bert))
                # Exclude optimizer and classification variables from saved variables
                var_list = self._get_saveable_variables(
                    exclude_scopes=('Optimizer', 'learning_rate', 'momentum',
                                    'output_weights', 'output_bias'))
                saver = tf.train.Saver(var_list)
                saver.restore(self.sess, pretrained_bert)

        if self.load_path is not None:
            self.load()

    def _init_graph(self):
        self._init_placeholders()

        self.bert = BertModel(
            config=self.bert_config,
            is_training=self.is_train_ph,
            input_ids=self.input_ids_ph,
            input_mask=self.input_masks_ph,
            token_type_ids=self.token_types_ph,
            use_one_hot_embeddings=False,
        )

        output_layer = self.bert.get_pooled_output()
        hidden_size = output_layer.shape[-1].value

        output_weights = tf.get_variable(
            "output_weights", [self.n_classes, hidden_size],
            initializer=tf.truncated_normal_initializer(stddev=0.02))

        output_bias = tf.get_variable("output_bias", [self.n_classes],
                                      initializer=tf.zeros_initializer())

        with tf.variable_scope("loss"):
            output_layer = tf.nn.dropout(output_layer,
                                         keep_prob=self.keep_prob_ph)
            logits = tf.matmul(output_layer, output_weights, transpose_b=True)
            logits = tf.nn.bias_add(logits, output_bias)

            if self.one_hot_labels:
                one_hot_labels = self.y_ph
            else:
                one_hot_labels = tf.one_hot(self.y_ph,
                                            depth=self.n_classes,
                                            dtype=tf.float32)

            self.y_predictions = tf.argmax(logits, axis=-1)
            if not self.multilabel:
                log_probs = tf.nn.log_softmax(logits, axis=-1)
                self.y_probas = tf.nn.softmax(logits, axis=-1)
                per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs,
                                                  axis=-1)
                self.loss = tf.reduce_mean(per_example_loss)
            else:
                self.y_probas = tf.nn.sigmoid(logits)
                self.loss = tf.reduce_mean(
                    tf.nn.sigmoid_cross_entropy_with_logits(
                        labels=one_hot_labels, logits=logits))

    def _init_placeholders(self):
        self.input_ids_ph = tf.placeholder(shape=(None, None),
                                           dtype=tf.int32,
                                           name='ids_ph')
        self.input_masks_ph = tf.placeholder(shape=(None, None),
                                             dtype=tf.int32,
                                             name='masks_ph')
        self.token_types_ph = tf.placeholder(shape=(None, None),
                                             dtype=tf.int32,
                                             name='token_types_ph')

        if not self.one_hot_labels:
            self.y_ph = tf.placeholder(shape=(None, ),
                                       dtype=tf.int32,
                                       name='y_ph')
        else:
            self.y_ph = tf.placeholder(shape=(None, self.n_classes),
                                       dtype=tf.float32,
                                       name='y_ph')

        self.learning_rate_ph = tf.placeholder_with_default(
            0.0, shape=[], name='learning_rate_ph')
        self.keep_prob_ph = tf.placeholder_with_default(1.0,
                                                        shape=[],
                                                        name='keep_prob_ph')
        self.is_train_ph = tf.placeholder_with_default(False,
                                                       shape=[],
                                                       name='is_train_ph')

    def _init_optimizer(self):
        with tf.variable_scope('Optimizer'):
            self.global_step = tf.get_variable(
                'global_step',
                shape=[],
                dtype=tf.int32,
                initializer=tf.constant_initializer(0),
                trainable=False)
            # default optimizer for Bert is Adam with fixed L2 regularization
            if self.optimizer is None:
                self.optimizer = AdamWeightDecayOptimizer(
                    learning_rate=self.learning_rate_ph,
                    weight_decay_rate=self.weight_decay_rate,
                    beta_1=0.9,
                    beta_2=0.999,
                    epsilon=1e-6,
                    exclude_from_weight_decay=[
                        "LayerNorm", "layer_norm", "bias"
                    ])

    def split(self,
              features: List[InputFeatures],
              y: Union[Optional[List[int]], List[List[int]]] = None):
        """
        Splits features: batch of InputFeatures
         on num_parts equal parts
        making num_parts batches instead
        """

        num_parts = self.gradient_accumulation_steps
        assert num_parts > 0
        assert num_parts <= len(features)

        num_features = math.ceil(len(features) / num_parts)
        feature_batches = [
            features[i * num_features:(i + 1) * num_features]
            for i in range(num_parts)
        ]
        if y is not None:
            y_batches = [
                y[i * num_features:(i + 1) * num_features]
                for i in range(num_parts)
            ]
        else:
            y_batches = []
        return feature_batches, y_batches

    def train_on_batch(self,
                       features: List[InputFeatures],
                       y: Union[List[int], List[List[int]]] = None) -> Dict:
        """Train model on given batch.
        This method accumulates gradients over sub-batches and applies a single
        optimizer step (see the graph-mode accumulation sketch after this example).
        Args:
            features: batch of InputFeatures
            y: batch of labels (class id or one-hot encoding)
        Returns:
            dict with loss and learning_rate values
        """

        # get trainable variables

        train_vars = tf.trainable_variables()
        accumulated_gradient = [
            tf.zeros_like(this_var) for this_var in train_vars
        ]
        feature_batches, y_batches = self.split(features, y)
        feed_dicts = [
            self._build_feed_dict(input_ids=[f.input_ids for f in feature_batch],
                                  input_masks=[f.input_mask for f in feature_batch],
                                  token_types=[f.input_type_ids for f in feature_batch],
                                  y=y_batch)
            for feature_batch, y_batch in zip(feature_batches, y_batches)
        ]
        learning_rate = max(self.get_learning_rate(), self.min_learning_rate)
        total_batch_loss = 0
        # https://stackoverflow.com/questions/59893850/how-to-accumulate-gradients-in-tensorflow-2-0
        for feed_dict in feed_dicts:
            with tf.GradientTape() as tape:
                loss_value = self.sess.run(self.loss, feed_dict=feed_dict)
                total_batch_loss += loss_value
                gradients = tape.gradient(loss_value, train_vars)
                accumulated_gradient = [(accum_grad + grad)
                                        for accum_grad, grad in zip(
                                            accumulated_gradient, gradients)]
        # Now, after executing all the tapes you needed, we apply the optimization step
        # (but first we take the average of the gradients)
        accumulated_gradient = [
            this_grad / self.gradient_accumulation_steps
            for this_grad in accumulated_gradient
        ]
        # apply optimization step
        self.optimizer.apply_gradients(zip(accumulated_gradient, train_vars))
        batch_loss = total_batch_loss / self.gradient_accumulation_steps
        return {'loss': batch_loss, 'learning_rate': learning_rate}

    def __call__(
            self, features: List[InputFeatures]
    ) -> Union[List[int], List[List[float]]]:
        """Make prediction for given features (texts).

        Args:
            features: batch of InputFeatures

        Returns:
            predicted classes or probabilities of each class

        """
        input_ids = [f.input_ids for f in features]
        input_masks = [f.input_mask for f in features]
        input_type_ids = [f.input_type_ids for f in features]

        feed_dict = self._build_feed_dict(input_ids, input_masks,
                                          input_type_ids)
        if not self.return_probas:
            pred = self.sess.run(self.y_predictions, feed_dict=feed_dict)
        else:
            pred = self.sess.run(self.y_probas, feed_dict=feed_dict)
        return pred
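The train_on_batch above accumulates gradients over sub-batches; note that tf.GradientTape only records eager execution, so in graph-mode TensorFlow 1.x the same idea is usually built from explicit accumulator ops. A minimal, self-contained sketch with a toy model and hypothetical names (not code from the class above):

import numpy as np
import tensorflow as tf

x_ph = tf.placeholder(tf.float32, shape=(None, 4))
y_ph = tf.placeholder(tf.float32, shape=(None, 1))
pred = tf.layers.dense(x_ph, 1)
loss = tf.losses.mean_squared_error(labels=y_ph, predictions=pred)

train_vars = tf.trainable_variables()
grads = tf.gradients(loss, train_vars)
optimizer = tf.train.AdamOptimizer(1e-3)

# one non-trainable accumulator variable per trainable variable
accumulators = [tf.Variable(tf.zeros_like(v), trainable=False) for v in train_vars]
zero_op = tf.group(*[a.assign(tf.zeros_like(a)) for a in accumulators])
accum_op = tf.group(*[a.assign_add(g) for a, g in zip(accumulators, grads)])
n_accum = 4  # sub-batches per optimizer step
apply_op = optimizer.apply_gradients([(a / n_accum, v) for a, v in zip(accumulators, train_vars)])

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(zero_op)
    for _ in range(n_accum):
        x_sub = np.random.rand(8, 4).astype(np.float32)
        y_sub = np.random.rand(8, 1).astype(np.float32)
        sess.run(accum_op, feed_dict={x_ph: x_sub, y_ph: y_sub})
    sess.run(apply_op)  # one optimizer step with the averaged gradients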
Example #11
    def _init_graph(self) -> None:
        self._init_placeholders()

        self.seq_lengths = tf.reduce_sum(self.y_masks_ph, axis=1)

        self.bert = BertModel(config=self.bert_config,
                              is_training=self.is_train_ph,
                              input_ids=self.input_ids_ph,
                              input_mask=self.input_masks_ph,
                              token_type_ids=self.token_types_ph,
                              use_one_hot_embeddings=False)

        encoder_layers = [
            self.bert.all_encoder_layers[i] for i in self.encoder_layer_ids
        ]

        with tf.variable_scope('ner'):
            layer_weights = tf.get_variable('layer_weights_',
                                            shape=len(encoder_layers),
                                            initializer=tf.ones_initializer(),
                                            trainable=True)
            layer_weights = tf.unstack(layer_weights / len(encoder_layers))
            # TODO: may be stack and reduce_sum is faster
            units = sum(w * l for w, l in zip(layer_weights, encoder_layers))
            units = tf.nn.dropout(units, keep_prob=self.keep_prob_ph)
            if self.use_birnn:
                units, _ = bi_rnn(units,
                                  self.birnn_hidden_size,
                                  cell_type=self.birnn_cell_type,
                                  seq_lengths=self.seq_lengths,
                                  name='birnn')
                units = tf.concat(units, -1)
            # TODO: maybe add one more layer?
            logits = tf.layers.dense(units,
                                     units=self.n_tags,
                                     name="output_dense")

            self.logits = self.token_from_subtoken(logits, self.y_masks_ph)

            max_length = tf.reduce_max(self.seq_lengths)
            one_hot_max_len = tf.one_hot(self.seq_lengths - 1, max_length)
            tag_mask = tf.cumsum(one_hot_max_len[:, ::-1], axis=1)[:, ::-1]

            # CRF
            if self.use_crf:
                transition_params = tf.get_variable(
                    'Transition_Params',
                    shape=[self.n_tags, self.n_tags],
                    initializer=tf.zeros_initializer())
                log_likelihood, transition_params = \
                    tf.contrib.crf.crf_log_likelihood(self.logits,
                                                      self.y_ph,
                                                      self.seq_lengths,
                                                      transition_params)
                loss_tensor = -log_likelihood
                self._transition_params = transition_params

            self.y_predictions = tf.argmax(self.logits, -1)
            self.y_probas = tf.nn.softmax(self.logits, axis=2)

        with tf.variable_scope("loss"):
            y_mask = tf.cast(tag_mask, tf.float32)
            if self.use_crf:
                self.loss = tf.reduce_mean(loss_tensor)
            else:
                self.loss = tf.losses.sparse_softmax_cross_entropy(
                    labels=self.y_ph, logits=self.logits, weights=y_mask)
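When use_crf is enabled above, the learned transition parameters are normally combined with Viterbi decoding at inference time rather than the per-token argmax. A standalone sketch with toy shapes, using tf.contrib.crf from TF 1.x (not part of the original example):

import tensorflow as tf

# toy inputs: batch of 2 sequences, up to 5 tokens, 3 tags
logits = tf.random_normal((2, 5, 3))
transition_params = tf.random_normal((3, 3))
seq_lengths = tf.constant([5, 3])

# Viterbi decoding with the learned transition matrix
viterbi_tags, viterbi_scores = tf.contrib.crf.crf_decode(logits, transition_params, seq_lengths)

with tf.Session() as sess:
    print(sess.run(viterbi_tags))  # best tag sequence for each example, shape (2, 5)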
Example #12
class BertClassifierModel(LRScheduledTFModel):
    """Bert-based model for text classification.

    It uses the output for the [CLS] token and predicts labels with a linear
    transformation. (A usage sketch follows this example.)

    Args:
        bert_config_file: path to Bert configuration file
        n_classes: number of classes
        keep_prob: dropout keep_prob for non-Bert layers
        one_hot_labels: set True if one-hot encoding for labels is used
        multilabel: set True if it is multi-label classification
        return_probas: set True to return class probabilities instead of the most probable label
        attention_probs_keep_prob: keep_prob for Bert self-attention layers
        hidden_keep_prob: keep_prob for Bert hidden layers
        optimizer: name of tf.train.* optimizer or None for `AdamWeightDecayOptimizer`
        num_warmup_steps:
        weight_decay_rate: L2 weight decay for `AdamWeightDecayOptimizer`
        pretrained_bert: pretrained Bert checkpoint
        min_learning_rate: min value of learning rate if learning rate decay is used
    """

    # TODO: add warmup
    # TODO: add head-only pre-training
    def __init__(self,
                 bert_config_file,
                 n_classes,
                 keep_prob,
                 one_hot_labels=False,
                 multilabel=False,
                 return_probas=False,
                 attention_probs_keep_prob=None,
                 hidden_keep_prob=None,
                 optimizer=None,
                 num_warmup_steps=None,
                 weight_decay_rate=0.01,
                 pretrained_bert=None,
                 min_learning_rate=1e-06,
                 **kwargs) -> None:
        super().__init__(**kwargs)

        self.return_probas = return_probas
        self.n_classes = n_classes
        self.min_learning_rate = min_learning_rate
        self.keep_prob = keep_prob
        self.one_hot_labels = one_hot_labels
        self.multilabel = multilabel
        self.optimizer = optimizer
        self.num_warmup_steps = num_warmup_steps
        self.weight_decay_rate = weight_decay_rate

        if self.multilabel and not self.one_hot_labels:
            raise RuntimeError(
                'Use one-hot encoded labels for multilabel classification!')

        if self.multilabel and not self.return_probas:
            raise RuntimeError(
                'Set return_probas to True for multilabel classification!')

        self.bert_config = BertConfig.from_json_file(
            str(expand_path(bert_config_file)))

        if attention_probs_keep_prob is not None:
            self.bert_config.attention_probs_dropout_prob = 1.0 - attention_probs_keep_prob
        if hidden_keep_prob is not None:
            self.bert_config.hidden_dropout_prob = 1.0 - hidden_keep_prob

        self.sess_config = tf.ConfigProto(allow_soft_placement=True)
        self.sess_config.gpu_options.allow_growth = True
        self.sess = tf.Session(config=self.sess_config)

        self._init_graph()

        self._init_optimizer()

        self.sess.run(tf.global_variables_initializer())

        if pretrained_bert is not None:
            pretrained_bert = str(expand_path(pretrained_bert))

            if tf.train.checkpoint_exists(pretrained_bert) \
                    and not tf.train.checkpoint_exists(str(self.load_path.resolve())):
                logger.info('[initializing model with Bert from {}]'.format(
                    pretrained_bert))
                # Exclude optimizer and classification variables from saved variables
                var_list = self._get_saveable_variables(
                    exclude_scopes=('Optimizer', 'learning_rate', 'momentum',
                                    'output_weights', 'output_bias'))
                saver = tf.train.Saver(var_list)
                saver.restore(self.sess, pretrained_bert)

        if self.load_path is not None:
            self.load()

    def _init_graph(self):
        self._init_placeholders()

        self.bert = BertModel(
            config=self.bert_config,
            is_training=self.is_train_ph,
            input_ids=self.input_ids_ph,
            input_mask=self.input_masks_ph,
            token_type_ids=self.token_types_ph,
            use_one_hot_embeddings=False,
        )

        output_layer = self.bert.get_pooled_output()
        hidden_size = output_layer.shape[-1].value

        output_weights = tf.get_variable(
            "output_weights", [self.n_classes, hidden_size],
            initializer=tf.truncated_normal_initializer(stddev=0.02))

        output_bias = tf.get_variable("output_bias", [self.n_classes],
                                      initializer=tf.zeros_initializer())

        with tf.variable_scope("loss"):
            output_layer = tf.nn.dropout(output_layer,
                                         keep_prob=self.keep_prob_ph)
            logits = tf.matmul(output_layer, output_weights, transpose_b=True)
            logits = tf.nn.bias_add(logits, output_bias)

            if self.one_hot_labels:
                one_hot_labels = self.y_ph
            else:
                one_hot_labels = tf.one_hot(self.y_ph,
                                            depth=self.n_classes,
                                            dtype=tf.float32)

            self.y_predictions = tf.argmax(logits, axis=-1)
            if not self.multilabel:
                log_probs = tf.nn.log_softmax(logits, axis=-1)
                self.y_probas = tf.nn.softmax(logits, axis=-1)
                per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs,
                                                  axis=-1)
                self.loss = tf.reduce_mean(per_example_loss)
            else:
                self.y_probas = tf.nn.sigmoid(logits)
                self.loss = tf.reduce_mean(
                    tf.nn.sigmoid_cross_entropy_with_logits(
                        labels=one_hot_labels, logits=logits))

    def _init_placeholders(self):
        self.input_ids_ph = tf.placeholder(shape=(None, None),
                                           dtype=tf.int32,
                                           name='ids_ph')
        self.input_masks_ph = tf.placeholder(shape=(None, None),
                                             dtype=tf.int32,
                                             name='masks_ph')
        self.token_types_ph = tf.placeholder(shape=(None, None),
                                             dtype=tf.int32,
                                             name='token_types_ph')

        if not self.one_hot_labels:
            self.y_ph = tf.placeholder(shape=(None, ),
                                       dtype=tf.int32,
                                       name='y_ph')
        else:
            self.y_ph = tf.placeholder(shape=(None, self.n_classes),
                                       dtype=tf.float32,
                                       name='y_ph')

        self.learning_rate_ph = tf.placeholder_with_default(
            0.0, shape=[], name='learning_rate_ph')
        self.keep_prob_ph = tf.placeholder_with_default(1.0,
                                                        shape=[],
                                                        name='keep_prob_ph')
        self.is_train_ph = tf.placeholder_with_default(False,
                                                       shape=[],
                                                       name='is_train_ph')

    def _init_optimizer(self):
        with tf.variable_scope('Optimizer'):
            self.global_step = tf.get_variable(
                'global_step',
                shape=[],
                dtype=tf.int32,
                initializer=tf.constant_initializer(0),
                trainable=False)
            # default optimizer for Bert is Adam with fixed L2 regularization
            if self.optimizer is None:

                self.train_op = self.get_train_op(
                    self.loss,
                    learning_rate=self.learning_rate_ph,
                    optimizer=AdamWeightDecayOptimizer,
                    weight_decay_rate=self.weight_decay_rate,
                    beta_1=0.9,
                    beta_2=0.999,
                    epsilon=1e-6,
                    exclude_from_weight_decay=[
                        "LayerNorm", "layer_norm", "bias"
                    ])
            else:
                self.train_op = self.get_train_op(
                    self.loss, learning_rate=self.learning_rate_ph)

            if self.optimizer is None:
                new_global_step = self.global_step + 1
                self.train_op = tf.group(
                    self.train_op, [self.global_step.assign(new_global_step)])

    def _build_feed_dict(self, input_ids, input_masks, token_types, y=None):
        feed_dict = {
            self.input_ids_ph: input_ids,
            self.input_masks_ph: input_masks,
            self.token_types_ph: token_types,
        }
        if y is not None:
            feed_dict.update({
                self.y_ph: y,
                self.learning_rate_ph: max(self.get_learning_rate(),
                                           self.min_learning_rate),
                self.keep_prob_ph: self.keep_prob,
                self.is_train_ph: True,
            })

        return feed_dict

    def train_on_batch(self, features: List[InputFeatures],
                       y: Union[List[int], List[List[int]]]) -> Dict:
        """Train model on given batch.
        This method calls train_op using features and y (labels).

        Args:
            features: batch of InputFeatures
            y: batch of labels (class id or one-hot encoding)

        Returns:
            dict with loss and learning_rate values

        """
        input_ids = [f.input_ids for f in features]
        input_masks = [f.input_mask for f in features]
        input_type_ids = [f.input_type_ids for f in features]

        feed_dict = self._build_feed_dict(input_ids, input_masks,
                                          input_type_ids, y)

        _, loss = self.sess.run([self.train_op, self.loss],
                                feed_dict=feed_dict)
        return {
            'loss': loss,
            'learning_rate': feed_dict[self.learning_rate_ph]
        }

    def __call__(
            self, features: List[InputFeatures]
    ) -> Union[List[int], List[List[float]]]:
        """Make prediction for given features (texts).

        Args:
            features: batch of InputFeatures

        Returns:
            predicted classes or probabilities of each class

        """
        input_ids = [f.input_ids for f in features]
        input_masks = [f.input_mask for f in features]
        input_type_ids = [f.input_type_ids for f in features]

        feed_dict = self._build_feed_dict(input_ids, input_masks,
                                          input_type_ids)
        if not self.return_probas:
            pred = self.sess.run(self.y_predictions, feed_dict=feed_dict)
        else:
            pred = self.sess.run(self.y_probas, feed_dict=feed_dict)
        return pred
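A hedged usage sketch for the classifier above, reusing BertPreprocessor from the standalone snippet earlier on this page. The paths are placeholders, and save_path/load_path/learning_rate are assumed to be handled by the LRScheduledTFModel parent (the class above refers to self.load_path and self.get_learning_rate()):

from deeppavlov.models.preprocessors.bert_preprocessor import BertPreprocessor

preprocessor = BertPreprocessor(vocab_file='./bert_model/vocab.txt',
                                do_lower_case=False,
                                max_seq_length=64)

model = BertClassifierModel(bert_config_file='./bert_model/bert_config.json',
                            n_classes=2,
                            keep_prob=0.5,
                            return_probas=True,
                            pretrained_bert='./bert_model/bert_model.ckpt',
                            save_path='./classifier/model',
                            load_path='./classifier/model',
                            learning_rate=2e-5)

features = preprocessor(['this is great', 'this is awful'])
probas = model(features)                  # (batch, n_classes) array when return_probas=True
model.train_on_batch(features, y=[1, 0])  # one training step on labelled examples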
Example #13
class BertRankerModel(LRScheduledTFModel):
    # TODO: docs
    # TODO: add head-only pre-training
    def __init__(self,
                 bert_config_file,
                 n_classes,
                 keep_prob,
                 batch_size,
                 num_ranking_samples,
                 one_hot_labels=False,
                 attention_probs_keep_prob=None,
                 hidden_keep_prob=None,
                 pretrained_bert=None,
                 resps=None,
                 resp_vecs=None,
                 resp_features=None,
                 resp_eval=True,
                 conts=None,
                 cont_vecs=None,
                 cont_features=None,
                 cont_eval=True,
                 bot_mode=0,
                 min_learning_rate=1e-06,
                 **kwargs) -> None:
        super().__init__(**kwargs)

        self.batch_size = batch_size
        self.num_ranking_samples = num_ranking_samples
        self.n_classes = n_classes
        self.min_learning_rate = min_learning_rate
        self.keep_prob = keep_prob
        self.one_hot_labels = one_hot_labels
        self.batch_size = batch_size
        self.resp_eval = resp_eval
        self.resps = resps
        self.resp_vecs = resp_vecs
        self.cont_eval = cont_eval
        self.conts = conts
        self.cont_vecs = cont_vecs
        self.bot_mode = bot_mode

        self.bert_config = BertConfig.from_json_file(
            str(expand_path(bert_config_file)))

        if attention_probs_keep_prob is not None:
            self.bert_config.attention_probs_dropout_prob = 1.0 - attention_probs_keep_prob
        if hidden_keep_prob is not None:
            self.bert_config.hidden_dropout_prob = 1.0 - hidden_keep_prob

        self.sess_config = tf.ConfigProto(allow_soft_placement=True)
        self.sess_config.gpu_options.allow_growth = True
        self.sess = tf.Session(config=self.sess_config)

        self._init_graph()

        self._init_optimizer()

        self.sess.run(tf.global_variables_initializer())

        if pretrained_bert is not None:
            pretrained_bert = str(expand_path(pretrained_bert))

            if tf.train.checkpoint_exists(pretrained_bert) \
                    and not tf.train.checkpoint_exists(str(self.load_path.resolve())):
                logger.info('[initializing model with Bert from {}]'.format(
                    pretrained_bert))
                # Exclude optimizer and classification variables from saved variables
                var_list = self._get_saveable_variables(
                    exclude_scopes=('Optimizer', 'learning_rate', 'momentum',
                                    'classification'))
                saver = tf.train.Saver(var_list)
                saver.restore(self.sess, pretrained_bert)

        if self.load_path is not None:
            self.load()

        if self.resp_eval:
            assert (self.resps is not None)
            assert (self.resp_vecs is not None)
        if self.cont_eval:
            assert (self.conts is not None)
            assert (self.cont_vecs is not None)
        if self.resp_eval and self.cont_eval:
            assert (len(self.resps) == len(self.conts))

    def _init_graph(self):
        self._init_placeholders()
        with tf.variable_scope("model"):
            self.bert = BertModel(
                config=self.bert_config,
                is_training=self.is_train_ph,
                input_ids=self.input_ids_ph,
                input_mask=self.input_masks_ph,
                token_type_ids=self.token_types_ph,
                use_one_hot_embeddings=False,
            )

        output_layer_a = self.bert.get_pooled_output()

        with tf.variable_scope("loss"):
            with tf.variable_scope("loss"):
                self.loss = tf.contrib.losses.metric_learning.npairs_loss(
                    self.y_ph, output_layer_a, output_layer_a)
                self.y_probas = output_layer_a

    def _init_placeholders(self):
        self.input_ids_ph = tf.placeholder(shape=(None, None),
                                           dtype=tf.int32,
                                           name='ids_ph')
        self.input_masks_ph = tf.placeholder(shape=(None, None),
                                             dtype=tf.int32,
                                             name='masks_ph')
        self.token_types_ph = tf.placeholder(shape=(None, None),
                                             dtype=tf.int32,
                                             name='token_types_ph')

        if not self.one_hot_labels:
            self.y_ph = tf.placeholder(shape=(None, ),
                                       dtype=tf.int32,
                                       name='y_ph')
        else:
            self.y_ph = tf.placeholder(shape=(None, self.n_classes),
                                       dtype=tf.float32,
                                       name='y_ph')

        self.learning_rate_ph = tf.placeholder_with_default(
            0.0, shape=[], name='learning_rate_ph')
        self.keep_prob_ph = tf.placeholder_with_default(1.0,
                                                        shape=[],
                                                        name='keep_prob_ph')
        self.is_train_ph = tf.placeholder_with_default(False,
                                                       shape=[],
                                                       name='is_train_ph')

    def _init_optimizer(self):
        # TODO: use AdamWeightDecay optimizer
        with tf.variable_scope('Optimizer'):
            self.global_step = tf.get_variable(
                'global_step',
                shape=[],
                dtype=tf.int32,
                initializer=tf.constant_initializer(0),
                trainable=False)
            self.train_op = self.get_train_op(
                self.loss, learning_rate=self.learning_rate_ph)

    def _build_feed_dict(self, input_ids, input_masks, token_types, y=None):
        feed_dict = {
            self.input_ids_ph: input_ids,
            self.input_masks_ph: input_masks,
            self.token_types_ph: token_types,
        }
        if y is not None:
            feed_dict.update({
                self.y_ph: y,
                self.learning_rate_ph: max(self.get_learning_rate(),
                                           self.min_learning_rate),
                self.keep_prob_ph: self.keep_prob,
                self.is_train_ph: True,
            })

        return feed_dict

    def train_on_batch(self, features, y):
        pass

    def __call__(self, features_list):
        pred = []
        for features in features_list:
            input_ids = [f.input_ids for f in features]
            input_masks = [f.input_mask for f in features]
            input_type_ids = [f.input_type_ids for f in features]
            feed_dict = self._build_feed_dict(input_ids, input_masks,
                                              input_type_ids)
            p = self.sess.run(self.y_probas, feed_dict=feed_dict)
            if len(p.shape) == 1:
                p = np.expand_dims(p, 0)
            pred.append(p)
        pred = np.vstack(pred)
        pred = pred / np.linalg.norm(pred, keepdims=True)
        bs = pred.shape[0]
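        # bot_mode 0: for each context, return the response whose vector has the
        # highest similarity score with the predicted context embedding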
        if self.bot_mode == 0:
            s = pred @ self.resp_vecs.T
            ids = np.argmax(s, 1)
            ans = [[self.resps[ids[i]] for i in range(bs)],
                   [s[i][ids[i]] for i in range(bs)]]
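        # bot_mode 1: shortlist the 10 responses most similar to the prediction,
        # then re-rank the shortlist by context similarity (scores rescaled to [0, 1])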
        if self.bot_mode == 1:
            sr = (pred @ self.resp_vecs.T + 1) / 2
            sc = (pred @ self.cont_vecs.T + 1) / 2
            ids = np.argsort(sr, 1)[:, -10:]
            sc = [sc[i, ids[i]] for i in range(bs)]
            ids = [
                sorted(zip(ids[i], sc[i]), key=itemgetter(1), reverse=True)
                for i in range(bs)
            ]
            sc = [list(map(lambda x: x[1], ids[i])) for i in range(bs)]
            ids = [list(map(lambda x: x[0], ids[i])) for i in range(bs)]
            ans = [[self.resps[ids[i][0]] for i in range(bs)],
                   [float(sc[i][0]) for i in range(bs)]]
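        # bot_mode 2: shortlist the 10 candidates whose context vectors are most
        # similar to the prediction, then re-rank the shortlist by response similarity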
        if self.bot_mode == 2:
            sr = (pred @ self.resp_vecs.T + 1) / 2
            sc = (pred @ self.cont_vecs.T + 1) / 2
            ids = np.argsort(sc, 1)[:, -10:]
            sr = [sr[i, ids[i]] for i in range(bs)]
            ids = [
                sorted(zip(ids[i], sr[i]), key=itemgetter(1), reverse=True)
                for i in range(bs)
            ]
            sr = [list(map(lambda x: x[1], ids[i])) for i in range(bs)]
            ids = [list(map(lambda x: x[0], ids[i])) for i in range(bs)]
            ans = [[self.resps[ids[i][0]] for i in range(bs)],
                   [float(sr[i][0]) for i in range(bs)]]
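        # bot_mode 3: rank candidates by the sum of their response and context similarities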
        if self.bot_mode == 3:
            sr = pred @ self.resp_vecs.T
            sc = pred @ self.cont_vecs.T
            s = sr + sc
            ids = np.argmax(s, 1)
            ans = [[self.resps[ids[i]] for i in range(bs)],
                   [float(s[i][ids[i]]) for i in range(bs)]]
        return ans
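

# --- Illustrative sketch (not part of the original code) ----------------------
# A minimal, self-contained version of the retrieval step used in the
# ``bot_mode == 0`` branch above: the context embedding is L2-normalized,
# compared with a matrix of response vectors by dot product, and the best
# match is returned together with its score. ``context_vec``, ``resp_vecs``
# and ``resps`` are hypothetical stand-ins for ``pred``, ``self.resp_vecs``
# and ``self.resps``; the response vectors are assumed to be unit-normalized
# already. Relies on ``numpy as np`` being imported at module level.
def _demo_retrieve_best_response(context_vec, resp_vecs, resps):
    context_vec = context_vec / np.linalg.norm(context_vec)  # unit-length context vector
    scores = resp_vecs @ context_vec                          # cosine similarities, shape (n_responses,)
    best = int(np.argmax(scores))                             # index of the best-scoring response
    return resps[best], float(scores[best])

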
class BertAsSummarizer(TFModel):
    """Naive Extractive Summarization model based on BERT.
    The BERT model was trained on the Masked Language Modeling (MLM) and Next Sentence Prediction (NSP) tasks.
    The NSP head was trained to predict, given ``[CLS] text_a [SEP] text_b [SEP]``, whether text_b follows text_a in the original document.

    This NSP head can be used to greedily stack sentences from a long document, starting from an initial sentence:

    summary_0 = init_sentence

    summary_1 = summary_0 + argmax(nsp_score(candidates))

    summary_2 = summary_1 + argmax(nsp_score(candidates))

    ...

    where ``candidates`` are all sentences of the document that have not yet been added to the summary.

    Args:
        bert_config_file: path to Bert configuration file
        pretrained_bert: path to pretrained Bert checkpoint
        vocab_file: path to Bert vocabulary
        max_summary_length: limit on summary length, measured in sentences if ``max_summary_length_in_tokens``
            is ``False``, else in tokens.
        max_summary_length_in_tokens: if ``True``, the summary length is measured in tokens instead of sentences.
            Defaults to ``False``.
        max_seq_length: max sequence length in subtokens, including ``[SEP]`` and ``[CLS]`` tokens.
            `max_seq_length` is used in Bert to compute NSP scores. Defaults to ``128``.
        do_lower_case: set ``True`` if lowercasing is needed. Defaults to ``False``.
        lang: use ``ru_sent_tokenize`` for ``'ru'`` and ``nltk.sent_tokenize`` for other languages.
            Defaults to ``'ru'``.
    """

    def __init__(self, bert_config_file: str,
                 pretrained_bert: str,
                 vocab_file: str,
                 max_summary_length: int,
                 max_summary_length_in_tokens: Optional[bool] = False,
                 max_seq_length: Optional[int] = 128,
                 do_lower_case: Optional[bool] = False,
                 lang: Optional[str] = 'ru',
                 **kwargs) -> None:

        self.max_summary_length = max_summary_length
        self.max_summary_length_in_tokens = max_summary_length_in_tokens
        self.bert_config = BertConfig.from_json_file(str(expand_path(bert_config_file)))

        self.bert_preprocessor = BertPreprocessor(vocab_file=vocab_file, do_lower_case=do_lower_case,
                                                  max_seq_length=max_seq_length)

        self.tokenize_reg = re.compile(r"[\w']+|[^\w ]")

        if lang == 'ru':
            from ru_sent_tokenize import ru_sent_tokenize
            self.sent_tokenizer = ru_sent_tokenize
        else:
            from nltk import sent_tokenize
            self.sent_tokenizer = sent_tokenize

        self.sess_config = tf.ConfigProto(allow_soft_placement=True)
        self.sess_config.gpu_options.allow_growth = True
        self.sess = tf.Session(config=self.sess_config)

        self._init_graph()

        self.sess.run(tf.global_variables_initializer())

        if pretrained_bert is not None:
            pretrained_bert = str(expand_path(pretrained_bert))

            if tf.train.checkpoint_exists(pretrained_bert):
                logger.info('[initializing model with Bert from {}]'.format(pretrained_bert))
                tvars = tf.trainable_variables()
                assignment_map, _ = get_assignment_map_from_checkpoint(tvars, pretrained_bert)
                tf.train.init_from_checkpoint(pretrained_bert, assignment_map)

    def _init_graph(self):
        self._init_placeholders()

        self.bert = BertModel(config=self.bert_config,
                              is_training=self.is_train_ph,
                              input_ids=self.input_ids_ph,
                              input_mask=self.input_masks_ph,
                              token_type_ids=self.token_types_ph,
                              use_one_hot_embeddings=False,
                              )
        # next sentence prediction head
        with tf.variable_scope("cls/seq_relationship"):
            output_weights = tf.get_variable(
                "output_weights",
                shape=[2, self.bert_config.hidden_size],
                initializer=create_initializer(self.bert_config.initializer_range))
            output_bias = tf.get_variable(
                "output_bias", shape=[2], initializer=tf.zeros_initializer())

        nsp_logits = tf.matmul(self.bert.get_pooled_output(), output_weights, transpose_b=True)
        nsp_logits = tf.nn.bias_add(nsp_logits, output_bias)
        self.nsp_probs = tf.nn.softmax(nsp_logits, axis=-1)

    def _init_placeholders(self):
        self.input_ids_ph = tf.placeholder(shape=(None, None), dtype=tf.int32, name='ids_ph')
        self.input_masks_ph = tf.placeholder(shape=(None, None), dtype=tf.int32, name='masks_ph')
        self.token_types_ph = tf.placeholder(shape=(None, None), dtype=tf.int32, name='token_types_ph')

        self.is_train_ph = tf.placeholder_with_default(False, shape=[], name='is_train_ph')

    def _build_feed_dict(self, input_ids, input_masks, token_types):
        feed_dict = {
            self.input_ids_ph: input_ids,
            self.input_masks_ph: input_masks,
            self.token_types_ph: token_types,
        }
        return feed_dict

    def _get_nsp_predictions(self, sentences: List[str], candidates: List[str]):
        """Compute NextSentence probability for every (sentence_i, candidate_i) pair.

        [CLS] sentence_i [SEP] candidate_i [SEP]

        Args:
            sentences: list of sentences
            candidates: list of candidates to be the next sentence

        Returns:
            probabilities that each candidate is the next sentence of the corresponding sentence
        """
        features = self.bert_preprocessor(texts_a=sentences, texts_b=candidates)
        input_ids = [f.input_ids for f in features]
        input_masks = [f.input_mask for f in features]
        input_type_ids = [f.input_type_ids for f in features]
        feed_dict = self._build_feed_dict(input_ids, input_masks, input_type_ids)
        nsp_probs = self.sess.run(self.nsp_probs, feed_dict=feed_dict)
        return nsp_probs[:, 0]

    def __call__(self, texts: List[str], init_sentences: Optional[List[str]] = None) -> List[List[str]]:
        """Builds summary for text from `texts`

        Args:
            texts: texts to build summaries for
            init_sentences: optional first sentences for the summaries; ``init_sentences[i]`` is used as the
                first sentence of the summary for ``texts[i]``. Defaults to ``None``.

        Returns:
            List[List[str]]: summaries tokenized on sentences
        """
        summaries = []
        # build summaries for each text, init_sentence pair
        if init_sentences is None:
            init_sentences = [None] * len(texts)

        for text, init_sentence in zip(texts, init_sentences):
            text_sentences = self.sent_tokenizer(text)

            if init_sentence is None:
                init_sentence = text_sentences[0]
                text_sentences = text_sentences[1:]

            # remove duplicates
            text_sentences = list(set(text_sentences))
            # remove init_sentence from text sentences
            text_sentences = [sent for sent in text_sentences if sent != init_sentence]

            summary = [init_sentence]
            if self.max_summary_length_in_tokens:
                # get length in tokens
                def get_length(x):
                    return len(self.tokenize_reg.findall(' '.join(x)))
            else:
                # get length as number of sentences
                get_length = len

            candidates = text_sentences[:]
            while len(candidates) > 0:
                # todo: use batches
                candidates_scores = [self._get_nsp_predictions([' '.join(summary)], [cand]) for cand in candidates]
                best_candidate_idx = np.argmax(candidates_scores)
                best_candidate = candidates[best_candidate_idx]
                del candidates[best_candidate_idx]
                if get_length(summary + [best_candidate]) > self.max_summary_length:
                    break
                summary = summary + [best_candidate]
            summaries += [summary]
        return summaries

    def train_on_batch(self, **kwargs):
        raise NotImplementedError
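

# --- Illustrative sketch (not part of the original code) ----------------------
# The greedy loop in ``BertAsSummarizer.__call__`` reduces to the following:
# starting from an initial sentence, repeatedly append the candidate with the
# highest next-sentence score against the current summary until the length
# limit would be exceeded. ``score_fn`` is a hypothetical stand-in for
# ``self._get_nsp_predictions``, and ``max_sentences`` counts sentences (the
# original class also supports a token-based limit). Relies on ``numpy as np``
# being imported at module level.
def _demo_greedy_summary(init_sentence, candidates, score_fn, max_sentences):
    summary = [init_sentence]
    candidates = list(candidates)
    while candidates:
        # score every remaining candidate as a continuation of the current summary
        scores = [score_fn(' '.join(summary), cand) for cand in candidates]
        best_idx = int(np.argmax(scores))
        best_candidate = candidates.pop(best_idx)
        if len(summary) + 1 > max_sentences:
            break
        summary.append(best_candidate)
    return summary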