Example #1
    def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
        """The `model_fn` for TPUEstimator."""

        unique_ids = features["unique_ids"]
        input_ids = features["input_ids"]
        input_mask = features["input_mask"]
        input_type_ids = features["input_type_ids"]

        model = modeling.BertModel(
            config=bert_config,
            is_training=False,
            input_ids=input_ids,
            input_mask=input_mask,
            token_type_ids=input_type_ids,
            use_one_hot_embeddings=use_one_hot_embeddings)

        if mode != tf.estimator.ModeKeys.PREDICT:
            raise ValueError("Only PREDICT modes are supported: %s" % (mode))

        tvars = tf.trainable_variables()
        scaffold_fn = None
        (assignment_map, initialized_variable_names
         ) = modeling.get_assignment_map_from_checkpoint(
             tvars, init_checkpoint)
        if use_tpu:

            def tpu_scaffold():
                tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
                return tf.train.Scaffold()

            scaffold_fn = tpu_scaffold
        else:
            tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

        tf.logging.info("**** Trainable Variables ****")
        for var in tvars:
            init_string = ""
            if var.name in initialized_variable_names:
                init_string = ", *INIT_FROM_CKPT*"
            tf.logging.info("  name = %s, shape = %s%s", var.name, var.shape,
                            init_string)

        all_layers = model.get_all_encoder_layers()

        predictions = {
            "unique_id": unique_ids,
        }

        for (i, layer_index) in enumerate(layer_indexes):
            predictions["layer_output_%d" % i] = all_layers[layer_index]

        output_spec = tf.contrib.tpu.TPUEstimatorSpec(mode=mode,
                                                      predictions=predictions,
                                                      scaffold_fn=scaffold_fn)
        return output_spec
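
This `model_fn` is a closure: `bert_config`, `init_checkpoint`, `layer_indexes`, `use_tpu` and `use_one_hot_embeddings` are captured from an enclosing builder that is not shown here. A minimal sketch of what that builder presumably looks like (signature inferred from the closure, body elided):

def model_fn_builder(bert_config, init_checkpoint, layer_indexes, use_tpu,
                     use_one_hot_embeddings):
    """Returns the `model_fn` closure for TPUEstimator (assumed wrapper)."""

    def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
        ...  # body as in the example above

    return model_fn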
Example #2
File: bert.py  Project: googol-lab/DAPPLE
    def __init__(self,
                 bert_config_file,
                 max_seq_length,
                 is_training,
                 input_ids,
                 input_mask,
                 segment_ids,
                 labels,
                 use_one_hot_embeddings,
                 model_type='classification',
                 kwargs=None):

        bert_config = modeling.BertConfig.from_json_file(bert_config_file)
        if max_seq_length > bert_config.max_position_embeddings:
            raise ValueError(
                "Cannot use sequence length %d because the BERT model "
                "was only trained up to sequence length %d" %
                (max_seq_length, bert_config.max_position_embeddings))

        self.model = modeling.BertModel(
            config=bert_config,
            is_training=is_training,
            input_ids=input_ids,
            input_mask=input_mask,
            token_type_ids=segment_ids,
            use_one_hot_embeddings=use_one_hot_embeddings)

        self.bert_config = bert_config
        self.kwargs = kwargs
        self.labels = labels
        self.input_ids = input_ids

        if model_type == 'classification':
            self.build_output_layer_classification()
        elif model_type == 'regression':
            self.build_output_layer_regression()
        elif model_type == 'mrc':
            self.build_output_layer_squad()
        elif model_type == 'pretrain':
            self.build_pretrain()
        else:
            raise ValueError("model_type should be one of ['classification', "
                             "'regression', pretrain', 'mrc'].")

        self.saver = tf.train.Saver(var_list=tf.global_variables(),
                                    max_to_keep=2)
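
A hedged usage sketch of this wrapper. The actual class name is not shown above, so `Bert` below is a stand-in, and the shapes are example values only:

import tensorflow as tf

max_seq_length = 128
input_ids = tf.placeholder(tf.int32, [None, max_seq_length], name='input_ids')
input_mask = tf.placeholder(tf.int32, [None, max_seq_length], name='input_mask')
segment_ids = tf.placeholder(tf.int32, [None, max_seq_length], name='segment_ids')
labels = tf.placeholder(tf.int32, [None], name='labels')

model = Bert('bert_config.json', max_seq_length, is_training=True,
             input_ids=input_ids, input_mask=input_mask,
             segment_ids=segment_ids, labels=labels,
             use_one_hot_embeddings=False, model_type='classification')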
Example #3
    def _make_graph(self, bert_config):

        self._init_placeholders()

        self.input_mask = tf.sequence_mask(
            tf.to_int32(self.encoder_inputs_length),
            tf.reduce_max(self.encoder_inputs_length),
            dtype=tf.int32)

        self.segment_ids = tf.sequence_mask(
            tf.to_int32(self.encoder_inputs_length),
            tf.reduce_max(self.encoder_inputs_length),
            dtype=tf.int32)

        # Zero the mask out so every token is assigned to segment 0.
        self.segment_ids = 0 * self.segment_ids

        old_ = ((self.args.test_file is not None
                 and 'yelp' in self.args.test_file) or
                (self.args.input_file is not None
                 and 'yelp' in self.args.input_file))

        self.model = modeling.BertModel(
            config=bert_config,
            is_training=(self.mode == 'Train'),
            input_ids=self.encoder_inputs,
            input_mask=self.input_mask,
            token_type_ids=self.segment_ids,
            use_one_hot_embeddings=False,
            word_embedding_trainable=(self.mode == 'Train'),
        )

        encoder_outputs = self.model.get_pooled_output()

        with tf.variable_scope("classification") as scope:
            fc_output = tf.layers.dense(encoder_outputs,
                                        1024,
                                        activation=tf.nn.relu)
            projection_layer = layers_core.Dense(
                units=self.args.output_classes, name="projection_layer")
            with tf.device(get_device_str(self.args.num_gpus)):
                self.logits = tf.nn.tanh(projection_layer(
                    fc_output))  # [batch size, output_classes]

        # if self.mode == "Train":
        self._init_optimizer()
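
`_init_placeholders`, `self.encoder_inputs` and `self.encoder_inputs_length` are defined elsewhere in the class. A plausible sketch of the placeholder setup this graph assumes (names follow the usage above, shapes are guesses):

    def _init_placeholders(self):
        # Token ids, padded to the longest sequence in the batch.
        self.encoder_inputs = tf.placeholder(
            tf.int32, shape=[None, None], name='encoder_inputs')
        # True (unpadded) length of each sequence in the batch.
        self.encoder_inputs_length = tf.placeholder(
            tf.int32, shape=[None], name='encoder_inputs_length')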
Example #4
def create_model(bert_config, is_training, input_ids, input_mask, segment_ids,
                 labels, num_labels, use_one_hot_embeddings):
  """Creates a classification model."""
  model = modeling.BertModel(
      config=bert_config,
      is_training=is_training,
      input_ids=input_ids,
      input_mask=input_mask,
      token_type_ids=segment_ids,
      use_one_hot_embeddings=use_one_hot_embeddings)

  # In the demo, we are doing a simple classification task on the entire
  # segment.
  #
  # If you want to use the token-level output, use model.get_sequence_output()
  # instead.
  output_layer = model.get_pooled_output()

  hidden_size = output_layer.shape[-1].value

  output_weights = tf.get_variable(
      "output_weights", [num_labels, hidden_size],
      initializer=tf.truncated_normal_initializer(stddev=0.02))

  output_bias = tf.get_variable(
      "output_bias", [num_labels], initializer=tf.zeros_initializer())

  with tf.variable_scope("loss"):
    if is_training:
      # I.e., 0.1 dropout
      output_layer = tf.nn.dropout(output_layer, keep_prob=0.9)

    logits = tf.matmul(output_layer, output_weights, transpose_b=True)
    logits = tf.nn.bias_add(logits, output_bias)
    probabilities = tf.nn.softmax(logits, axis=-1)
    log_probs = tf.nn.log_softmax(logits, axis=-1)

    one_hot_labels = tf.one_hot(labels, depth=num_labels, dtype=tf.float32)

    per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1)
    loss = tf.reduce_mean(per_example_loss)

    return (loss, per_example_loss, logits, probabilities)
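
A hedged usage sketch of `create_model` for a binary classifier; the sequence length of 128 and `num_labels=2` are arbitrary example values:

bert_config = modeling.BertConfig.from_json_file('bert_config.json')
input_ids = tf.placeholder(tf.int32, [None, 128], name='input_ids')
input_mask = tf.placeholder(tf.int32, [None, 128], name='input_mask')
segment_ids = tf.placeholder(tf.int32, [None, 128], name='segment_ids')
labels = tf.placeholder(tf.int32, [None], name='labels')

loss, per_example_loss, logits, probabilities = create_model(
    bert_config, is_training=True, input_ids=input_ids, input_mask=input_mask,
    segment_ids=segment_ids, labels=labels, num_labels=2,
    use_one_hot_embeddings=False)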
Example #5
def model_fn_builder(bert_config, init_checkpoint, use_one_hot_embeddings):
    """Returns `model_fn` closure for TPUEstimator."""

    input_ids = tf.placeholder(tf.int32, [None, FLAGS.max_seq_length],
                               name='input_ids')
    input_mask = tf.placeholder(tf.int32, [None, FLAGS.max_seq_length],
                                name='input_mask')
    input_type_ids = tf.placeholder(tf.int32, [None, FLAGS.max_seq_length],
                                    name='segment_ids')
    model = modeling.BertModel(config=bert_config,
                               is_training=False,
                               input_ids=input_ids,
                               input_mask=input_mask,
                               token_type_ids=input_type_ids,
                               use_one_hot_embeddings=use_one_hot_embeddings)

    all_layer_outputs = [model.get_word_embedding_output()]
    all_layer_outputs += model.get_all_encoder_layers()

    if FLAGS.high_layer_idx == FLAGS.low_layer_idx:
        if FLAGS.use_cls_token:
            outputs = model.get_pooled_output()
        else:
            outputs = all_layer_outputs[FLAGS.low_layer_idx]
            outputs = mean_pool(outputs, input_mask)
    else:
        low_outputs = all_layer_outputs[FLAGS.low_layer_idx]
        low_outputs = mean_pool(low_outputs, input_mask)
        if FLAGS.use_cls_token:
            high_outputs = model.get_pooled_output()
        else:
            high_outputs = all_layer_outputs[FLAGS.high_layer_idx]
            high_outputs = mean_pool(high_outputs, input_mask)
        outputs = (low_outputs, high_outputs)

    tvars = tf.trainable_variables()
    (assignment_map,
     initialized_variable_names) = modeling.get_assignment_map_from_checkpoint(
         tvars, init_checkpoint)

    tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
    return input_ids, input_mask, input_type_ids, outputs
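
Despite its name, this builder returns input placeholders plus the selected feature tensors rather than a `model_fn`. A hedged sketch of driving the returned tensors in a session; `batch_input_ids`, `batch_input_mask` and `batch_segment_ids` stand for NumPy arrays prepared by the caller:

input_ids_ph, input_mask_ph, type_ids_ph, outputs = model_fn_builder(
    bert_config, FLAGS.init_checkpoint, use_one_hot_embeddings=False)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    features = sess.run(outputs, feed_dict={
        input_ids_ph: batch_input_ids,      # int32, [batch, FLAGS.max_seq_length]
        input_mask_ph: batch_input_mask,
        type_ids_ph: batch_segment_ids,
    })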
Example #6
        def create_model(self):
            input_ids = BertModelTest.ids_tensor(
                [self.batch_size, self.seq_length], self.vocab_size)

            input_mask = None
            if self.use_input_mask:
                input_mask = BertModelTest.ids_tensor(
                    [self.batch_size, self.seq_length], vocab_size=2)

            token_type_ids = None
            if self.use_token_type_ids:
                token_type_ids = BertModelTest.ids_tensor(
                    [self.batch_size, self.seq_length], self.type_vocab_size)

            config = modeling.BertConfig(
                vocab_size=self.vocab_size,
                hidden_size=self.hidden_size,
                num_hidden_layers=self.num_hidden_layers,
                num_attention_heads=self.num_attention_heads,
                intermediate_size=self.intermediate_size,
                hidden_act=self.hidden_act,
                hidden_dropout_prob=self.hidden_dropout_prob,
                attention_probs_dropout_prob=self.attention_probs_dropout_prob,
                max_position_embeddings=self.max_position_embeddings,
                type_vocab_size=self.type_vocab_size,
                initializer_range=self.initializer_range)

            model = modeling.BertModel(config=config,
                                       is_training=self.is_training,
                                       input_ids=input_ids,
                                       input_mask=input_mask,
                                       token_type_ids=token_type_ids,
                                       scope=self.scope)

            outputs = {
                "embedding_output": model.get_embedding_output(),
                "sequence_output": model.get_sequence_output(),
                "pooled_output": model.get_pooled_output(),
                "all_encoder_layers": model.get_all_encoder_layers(),
            }
            return outputs
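
`BertModelTest.ids_tensor` is a test helper that simply draws random ids. A plausible standalone sketch of it (the actual helper lives on the test class and may use Python's `random` module rather than NumPy):

import numpy as np
import tensorflow as tf

def ids_tensor(shape, vocab_size, name=None):
    """Creates a random int32 tensor of ids in [0, vocab_size)."""
    values = np.random.randint(low=0, high=vocab_size, size=shape)
    return tf.constant(values, dtype=tf.int32, name=name)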
Example #7
def create_model(bert_config, is_training, input_ids, input_mask, segment_ids,
                 use_one_hot_embeddings):
    """Creates a classification model."""
    model = modeling.BertModel(config=bert_config,
                               is_training=is_training,
                               input_ids=input_ids,
                               input_mask=input_mask,
                               token_type_ids=segment_ids,
                               use_one_hot_embeddings=use_one_hot_embeddings)

    final_hidden = model.get_sequence_output()

    final_hidden_shape = modeling.get_shape_list(final_hidden, expected_rank=3)
    batch_size = final_hidden_shape[0]
    seq_length = final_hidden_shape[1]
    hidden_size = final_hidden_shape[2]

    output_weights = tf.get_variable(
        "cls/squad/output_weights", [2, hidden_size],
        initializer=tf.truncated_normal_initializer(stddev=0.02))

    output_bias = tf.get_variable("cls/squad/output_bias", [2],
                                  initializer=tf.zeros_initializer())

    final_hidden_matrix = tf.reshape(final_hidden,
                                     [batch_size * seq_length, hidden_size])
    logits = tf.matmul(final_hidden_matrix, output_weights, transpose_b=True)
    logits = tf.nn.bias_add(logits, output_bias)

    logits = tf.reshape(logits, [batch_size, seq_length, 2])
    logits = tf.transpose(logits, [2, 0, 1])

    unstacked_logits = tf.unstack(logits, axis=0)

    (start_logits, end_logits) = (unstacked_logits[0], unstacked_logits[1])

    return (start_logits, end_logits)
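
At training time these span logits are usually turned into a loss against the gold start and end positions. A hedged sketch of that step following the common SQuAD recipe; the `start_positions` and `end_positions` placeholders are assumptions, not part of the snippet above:

def compute_loss(logits, positions, seq_length):
    one_hot_positions = tf.one_hot(positions, depth=seq_length, dtype=tf.float32)
    log_probs = tf.nn.log_softmax(logits, axis=-1)
    return -tf.reduce_mean(tf.reduce_sum(one_hot_positions * log_probs, axis=-1))

start_positions = tf.placeholder(tf.int32, [None], name='start_positions')
end_positions = tf.placeholder(tf.int32, [None], name='end_positions')
seq_length = modeling.get_shape_list(start_logits, expected_rank=2)[1]
total_loss = (compute_loss(start_logits, start_positions, seq_length) +
              compute_loss(end_logits, end_positions, seq_length)) / 2.0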
Example #8
    def cnn(self):
        '''Get the final token-level output of the BERT model via get_sequence_output and use it as the input embeddings of the CNN model.
        '''
        batch_tensor_list = []
        ATTENTION_SIZE = 200
        # i indexes one of the 32 bags; run BERT separately on the sentences
        # of each bag, i.e. iterate over every bag.
        for i in range(len(self.input_ids)):  # input_ids: (32, 7, 170)
            # input_ids[i] = (7 sentences, 170 tokens)
            with tf.name_scope('bert'):
                bert_model = modeling.BertModel(
                    config=self.bert_config,
                    is_training=self.config.is_training,
                    input_ids=self.input_ids[i],  # input_ids[i](7,170)
                    input_mask=self.input_mask[i],
                    token_type_ids=self.segment_ids[i],
                    use_one_hot_embeddings=self.config.use_one_hot_embeddings)
                embedding_inputs = bert_model.get_sequence_output()
            # embedding_inputs holds the 7 sentences, i.e. (7, 170, 768); the CNN
            # below first turns them into 7 sentence vectors: (7, 768).
            '''Use three convolution kernel sizes for convolution and pooling, and concat the three results.'''
            with tf.name_scope('conv'):
                pooled_outputs = []
                for filter_size in self.config.filter_sizes:
                    with tf.variable_scope("conv-maxpool-%s" % filter_size,
                                           reuse=False):
                        conv = tf.layers.conv1d(embedding_inputs,
                                                self.config.num_filters,
                                                filter_size,
                                                name='conv1d')
                        pooled = tf.reduce_max(conv,
                                               reduction_indices=[1],
                                               name='gmp')
                        pooled_outputs.append(pooled)

                num_filters_total = self.config.num_filters * len(
                    self.config.filter_sizes)
                h_pool = tf.concat(pooled_outputs, 1)
                outputs = tf.reshape(h_pool, [-1, num_filters_total])
            '''Add full connection layer and dropout layer'''
            with tf.name_scope('fc'):
                fc = tf.layers.dense(outputs,
                                     self.config.hidden_dim,
                                     name='fc1')
                fc = tf.nn.dropout(fc, self.keep_prob)
                fc = tf.nn.relu(fc)

            # Merge the 7 sentence vectors (7, 768) into a single bag vector
            # (1, 768) with attention.
            # TODO: write this!!
            # Attention layer
            with tf.name_scope('Attention_layer'):
                attention_output, alphas = attention(fc,
                                                     ATTENTION_SIZE,
                                                     return_alphas=True)
                tf.summary.histogram('alphas', alphas)

            # This yields one bag vector attention_output = (1, 768); looping over
            # all bags in the batch and stacking the 32 bag vectors gives (32, 768).
            batch_tensor_list.append(attention_output)
        # Stack the 32 bag vectors.
        batch_tensor = tf.stack(batch_tensor_list)
        # batch_tensor=(32, 768)
        '''logits'''
        with tf.name_scope('logits'):
            # batch_tensor = (32, 768); after the dense layer, logits = (32, 5).
            # self.logits = tf.layers.dense(batch_tensor, self.config.num_labels, name='logits')
            self.logits = tf.layers.dense(batch_tensor, 5, name='logits')

            self.prob = tf.nn.softmax(self.logits)
            self.y_pred_cls = tf.argmax(tf.nn.softmax(self.logits),
                                        1)  # y_pred_cls=(32,)
        '''Calculate the loss. Convert the gold labels into one-hot form.'''
        with tf.name_scope('loss'):

            log_probs = tf.nn.log_softmax(self.logits,
                                          axis=-1)  # (32,5)->(32, 5)
            one_hot_labels = tf.one_hot(self.labels,
                                        depth=self.config.num_labels,
                                        dtype=tf.float32)  # (32,5)
            per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs,
                                              axis=-1)
            self.loss = tf.reduce_mean(per_example_loss)
        '''optimizer'''
        with tf.name_scope('optimizer'):
            optimizer = tf.train.AdamOptimizer(self.config.lr)
            gradients, variables = zip(*optimizer.compute_gradients(self.loss))
            gradients, _ = tf.clip_by_global_norm(gradients, self.config.clip)
            self.optim = optimizer.apply_gradients(
                zip(gradients, variables), global_step=self.global_step)
        '''accuracy'''
        with tf.name_scope('accuracy'):
            correct_pred = tf.equal(self.labels, self.y_pred_cls)
            self.acc = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
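
The `attention` helper used in the bag-level merge is not shown. A plausible sketch of a simple additive attention over the sentence axis (the project's actual implementation may differ):

def attention(inputs, attention_size, return_alphas=False):
    """inputs: [num_sentences, hidden]; returns a single [hidden] bag vector."""
    hidden_size = inputs.shape[-1].value
    w_omega = tf.Variable(tf.random_normal([hidden_size, attention_size], stddev=0.1))
    b_omega = tf.Variable(tf.random_normal([attention_size], stddev=0.1))
    u_omega = tf.Variable(tf.random_normal([attention_size], stddev=0.1))

    v = tf.tanh(tf.tensordot(inputs, w_omega, axes=1) + b_omega)  # [num_sentences, att]
    vu = tf.tensordot(v, u_omega, axes=1)                         # [num_sentences]
    alphas = tf.nn.softmax(vu)
    output = tf.reduce_sum(inputs * tf.expand_dims(alphas, -1), axis=0)  # [hidden]
    return (output, alphas) if return_alphas else output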
Example #9
    def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
        """The `model_fn` for TPUEstimator."""

        tf.logging.info("*** Features ***")
        for name in sorted(features.keys()):
            tf.logging.info("  name = %s, shape = %s" %
                            (name, features[name].shape))

        input_ids = features["input_ids"]
        input_mask = features["input_mask"]
        segment_ids = features["segment_ids"]
        masked_lm_positions = features["masked_lm_positions"]
        masked_lm_ids = features["masked_lm_ids"]
        masked_lm_weights = features["masked_lm_weights"]
        next_sentence_labels = features["next_sentence_labels"]

        is_training = (mode == tf.estimator.ModeKeys.TRAIN)

        model = modeling.BertModel(
            config=bert_config,
            is_training=is_training,
            input_ids=input_ids,
            input_mask=input_mask,
            token_type_ids=segment_ids,
            use_one_hot_embeddings=use_one_hot_embeddings)

        (masked_lm_loss, masked_lm_example_loss,
         masked_lm_log_probs) = get_masked_lm_output(
             bert_config, model.get_sequence_output(),
             model.get_embedding_table(), masked_lm_positions, masked_lm_ids,
             masked_lm_weights)

        (next_sentence_loss, next_sentence_example_loss,
         next_sentence_log_probs) = get_next_sentence_output(
             bert_config, model.get_pooled_output(), next_sentence_labels)

        total_loss = masked_lm_loss + next_sentence_loss

        tvars = tf.trainable_variables()

        initialized_variable_names = {}
        scaffold_fn = None
        if init_checkpoint:
            (assignment_map, initialized_variable_names
             ) = modeling.get_assignment_map_from_checkpoint(
                 tvars, init_checkpoint)
            if use_tpu:

                def tpu_scaffold():
                    tf.train.init_from_checkpoint(init_checkpoint,
                                                  assignment_map)
                    return tf.train.Scaffold()

                scaffold_fn = tpu_scaffold
            else:
                tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

        tf.logging.info("**** Trainable Variables ****")
        for var in tvars:
            init_string = ""
            if var.name in initialized_variable_names:
                init_string = ", *INIT_FROM_CKPT*"
            tf.logging.info("  name = %s, shape = %s%s", var.name, var.shape,
                            init_string)

        output_spec = None
        if mode == tf.estimator.ModeKeys.TRAIN:
            train_op = optimization.create_optimizer(total_loss, learning_rate,
                                                     num_train_steps,
                                                     num_warmup_steps, use_tpu)

            output_spec = tf.contrib.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=total_loss,
                train_op=train_op,
                scaffold_fn=scaffold_fn)
        elif mode == tf.estimator.ModeKeys.EVAL:

            def metric_fn(masked_lm_example_loss, masked_lm_log_probs,
                          masked_lm_ids, masked_lm_weights,
                          next_sentence_example_loss, next_sentence_log_probs,
                          next_sentence_labels):
                """Computes the loss and accuracy of the model."""
                masked_lm_log_probs = tf.reshape(
                    masked_lm_log_probs, [-1, masked_lm_log_probs.shape[-1]])
                masked_lm_predictions = tf.argmax(masked_lm_log_probs,
                                                  axis=-1,
                                                  output_type=tf.int32)
                masked_lm_example_loss = tf.reshape(masked_lm_example_loss,
                                                    [-1])
                masked_lm_ids = tf.reshape(masked_lm_ids, [-1])
                masked_lm_weights = tf.reshape(masked_lm_weights, [-1])
                masked_lm_accuracy = tf.metrics.accuracy(
                    labels=masked_lm_ids,
                    predictions=masked_lm_predictions,
                    weights=masked_lm_weights)
                masked_lm_mean_loss = tf.metrics.mean(
                    values=masked_lm_example_loss, weights=masked_lm_weights)

                next_sentence_log_probs = tf.reshape(
                    next_sentence_log_probs,
                    [-1, next_sentence_log_probs.shape[-1]])
                next_sentence_predictions = tf.argmax(next_sentence_log_probs,
                                                      axis=-1,
                                                      output_type=tf.int32)
                next_sentence_labels = tf.reshape(next_sentence_labels, [-1])
                next_sentence_accuracy = tf.metrics.accuracy(
                    labels=next_sentence_labels,
                    predictions=next_sentence_predictions)
                next_sentence_mean_loss = tf.metrics.mean(
                    values=next_sentence_example_loss)

                return {
                    "masked_lm_accuracy": masked_lm_accuracy,
                    "masked_lm_loss": masked_lm_mean_loss,
                    "next_sentence_accuracy": next_sentence_accuracy,
                    "next_sentence_loss": next_sentence_mean_loss,
                }

            eval_metrics = (metric_fn, [
                masked_lm_example_loss, masked_lm_log_probs, masked_lm_ids,
                masked_lm_weights, next_sentence_example_loss,
                next_sentence_log_probs, next_sentence_labels
            ])
            output_spec = tf.contrib.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=total_loss,
                eval_metrics=eval_metrics,
                scaffold_fn=scaffold_fn)
        else:
            raise ValueError("Only TRAIN and EVAL modes are supported: %s" %
                             (mode))

        return output_spec
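
`get_masked_lm_output` and `get_next_sentence_output` come from the pre-training script. For reference, the next-sentence head is roughly the following (a sketch of the standard implementation; consult BERT's run_pretraining.py for the authoritative version):

def get_next_sentence_output(bert_config, input_tensor, labels):
    """Binary classification loss for next-sentence prediction."""
    with tf.variable_scope("cls/seq_relationship"):
        output_weights = tf.get_variable(
            "output_weights", shape=[2, bert_config.hidden_size],
            initializer=modeling.create_initializer(bert_config.initializer_range))
        output_bias = tf.get_variable(
            "output_bias", shape=[2], initializer=tf.zeros_initializer())

        logits = tf.matmul(input_tensor, output_weights, transpose_b=True)
        logits = tf.nn.bias_add(logits, output_bias)
        log_probs = tf.nn.log_softmax(logits, axis=-1)
        labels = tf.reshape(labels, [-1])
        one_hot_labels = tf.one_hot(labels, depth=2, dtype=tf.float32)
        per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1)
        loss = tf.reduce_mean(per_example_loss)
        return (loss, per_example_loss, log_probs)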
Example #10
def optimization_inversion():
    tokenizer = tokenization.FullTokenizer(vocab_file=FLAGS.vocab_file,
                                           do_lower_case=FLAGS.do_lower_case)
    cls_id = tokenizer.vocab['[CLS]']
    sep_id = tokenizer.vocab['[SEP]']
    mask_id = tokenizer.vocab['[MASK]']

    _, _, x, y = load_inversion_data()
    filters = [cls_id, sep_id, mask_id]
    y = filter_labels(y[0], filters)

    batch_size = FLAGS.batch_size
    seq_len = FLAGS.seq_len
    max_iters = FLAGS.max_iters

    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
    input_ids = tf.ones((batch_size, seq_len + 2), tf.int32)
    input_mask = tf.ones_like(input_ids, tf.int32)
    input_type_ids = tf.zeros_like(input_ids, tf.int32)

    model = modeling.BertModel(config=bert_config,
                               is_training=False,
                               input_ids=input_ids,
                               input_mask=input_mask,
                               token_type_ids=input_type_ids,
                               use_one_hot_embeddings=False)

    bert_vars = tf.trainable_variables()

    (assignment_map,
     _) = modeling.get_assignment_map_from_checkpoint(bert_vars,
                                                      FLAGS.init_checkpoint)
    tf.train.init_from_checkpoint(FLAGS.init_checkpoint, assignment_map)
    word_emb = model.embedding_table

    batch_cls_ids = tf.ones((batch_size, 1), tf.int32) * cls_id
    batch_sep_ids = tf.ones((batch_size, 1), tf.int32) * sep_id
    cls_emb = tf.nn.embedding_lookup(word_emb, batch_cls_ids)
    sep_emb = tf.nn.embedding_lookup(word_emb, batch_sep_ids)

    prob_mask = np.zeros((bert_config.vocab_size, ), np.float32)
    prob_mask[filters] = -1e9
    prob_mask = tf.constant(prob_mask, dtype=np.float32)

    logit_inputs = tf.get_variable(
        name='inputs',
        shape=(batch_size, seq_len, bert_config.vocab_size),
        initializer=tf.random_uniform_initializer(-0.1, 0.1))
    t_vars = [logit_inputs]
    t_var_names = {logit_inputs.name}

    logit_inputs += prob_mask
    prob_inputs = tf.nn.softmax(logit_inputs / FLAGS.temp, axis=-1)
    emb_inputs = tf.matmul(prob_inputs, word_emb)

    emb_inputs = tf.concat([cls_emb, emb_inputs, sep_emb], axis=1)
    if FLAGS.low_layer_idx == 0:
        encoded = mean_pool(emb_inputs, input_mask)
    else:
        encoded = encode(emb_inputs, input_ids, input_mask, input_type_ids,
                         bert_config)
    targets = tf.placeholder(tf.float32,
                             shape=(batch_size, encoded.shape.as_list()[-1]))
    loss = get_similarity_metric(encoded, targets, FLAGS.metric, rtn_loss=True)
    loss = tf.reduce_sum(loss)

    if FLAGS.alpha > 0.:
        # encourage the words to be different
        diff = tf.expand_dims(prob_inputs, 2) - tf.expand_dims(prob_inputs, 1)
        reg = tf.reduce_sum(-tf.exp(tf.reduce_sum(diff**2, axis=-1)), [1, 2])
        loss += FLAGS.alpha * tf.reduce_sum(reg)

    optimizer = tf.train.AdamOptimizer(FLAGS.lr)

    start_vars = set(v.name for v in tf.global_variables()
                     if v.name not in t_var_names)
    grads_and_vars = optimizer.compute_gradients(loss, t_vars)
    train_ops = optimizer.apply_gradients(
        grads_and_vars, global_step=tf.train.get_or_create_global_step())

    end_vars = tf.global_variables()
    new_vars = [v for v in end_vars if v.name not in start_vars]

    preds = tf.argmax(prob_inputs, axis=-1)
    batch_init_ops = tf.variables_initializer(new_vars)

    total_it = len(x) // batch_size

    with tf.Session() as sess:
        sess.run([tf.global_variables_initializer(), tf.tables_initializer()])

        def invert_one_batch(batch_targets):
            sess.run(batch_init_ops)
            feed_dict = {targets: batch_targets}
            prev = 1e6
            for i in range(max_iters):
                curr, _ = sess.run([loss, train_ops], feed_dict)
                # stop if no progress
                if (i + 1) % (max_iters // 10) == 0 and curr > prev:
                    break
                prev = curr
            return sess.run([preds, loss], feed_dict)

        start_time = time.time()
        it = 0.0
        all_tp, all_fp, all_fn, all_err = 0.0, 0.0, 0.0, 0.0

        for batch_idx in iterate_minibatches_indices(len(x), batch_size, False,
                                                     False):
            y_pred, err = invert_one_batch(x[batch_idx])
            tp, fp, fn = tp_fp_fn_metrics_np(y_pred, y[batch_idx])

            # for yp, yt in zip(y_pred, y[batch_idx]):
            #   print(','.join(set(tokenizer.convert_ids_to_tokens(yp))))
            #   print(','.join(set(tokenizer.convert_ids_to_tokens(yt))))

            it += 1.0
            all_err += err
            all_tp += tp
            all_fp += fp
            all_fn += fn

            all_pre = all_tp / (all_tp + all_fp + 1e-7)
            all_rec = all_tp / (all_tp + all_fn + 1e-7)
            all_f1 = 2 * all_pre * all_rec / (all_pre + all_rec + 1e-7)

            if it % FLAGS.print_every == 0:
                it_time = (time.time() - start_time) / it
                log("Iter {:.2f}%, err={}, pre={:.2f}%, rec={:.2f}%, f1={:.2f}%,"
                    " {:.2f} sec/it".format(it / total_it * 100, all_err / it,
                                            all_pre * 100, all_rec * 100,
                                            all_f1 * 100, it_time))

        all_pre = all_tp / (all_tp + all_fp + 1e-7)
        all_rec = all_tp / (all_tp + all_fn + 1e-7)
        all_f1 = 2 * all_pre * all_rec / (all_pre + all_rec + 1e-7)
        log("Final err={}, pre={:.2f}%, rec={:.2f}%, f1={:.2f}%".format(
            all_err / it, all_pre * 100, all_rec * 100, all_f1 * 100))
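
`mean_pool`, used here and in Example #5, is not defined in either snippet. A plausible sketch of a masked mean over the sequence axis (an assumption about what the helper does):

def mean_pool(sequence_output, input_mask):
    # sequence_output: [batch, seq_len, hidden]; input_mask: [batch, seq_len]
    mask = tf.cast(tf.expand_dims(input_mask, axis=-1), tf.float32)
    summed = tf.reduce_sum(sequence_output * mask, axis=1)
    counts = tf.maximum(tf.reduce_sum(mask, axis=1), 1e-9)
    return summed / counts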