import time

import numpy as np
import tensorflow.compat.v1 as tf  # keeps the TF1-style graph/session API used below

# modeling and tokenization are the ALBERT codebase modules (import path
# assumed here); FLAGS and helpers such as load_inversion_data, filter_labels,
# get_similarity_metric, iterate_minibatches_indices, tp_fp_fn_metrics_np and
# log are defined elsewhere in the surrounding module.
from albert import modeling, tokenization


def model_fn_builder(albert_config, init_checkpoint, use_one_hot_embeddings):
  """Builds the ALBERT encoder graph; returns placeholders and outputs."""

  input_ids = tf.placeholder(tf.int32, [None, FLAGS.max_seq_length],
                             name='input_ids')
  input_mask = tf.placeholder(tf.int32, [None, FLAGS.max_seq_length],
                              name='input_mask')
  input_type_ids = tf.placeholder(tf.int32, [None, FLAGS.max_seq_length],
                                  name='segment_ids')
  model = modeling.AlbertModel(
      config=albert_config,
      is_training=False,
      input_ids=input_ids,
      input_mask=input_mask,
      token_type_ids=input_type_ids,
      use_one_hot_embeddings=use_one_hot_embeddings)

  all_layer_outputs = [model.get_word_embedding_output()]
  all_layer_outputs += model.get_all_encoder_layers()
  if FLAGS.high_layer_idx == FLAGS.low_layer_idx:
    if FLAGS.use_cls_token:
      outputs = model.get_pooled_output()
    else:
      outputs = all_layer_outputs[FLAGS.high_layer_idx]
      outputs = mean_pool(outputs, input_mask)
  else:
    low_outputs = all_layer_outputs[FLAGS.low_layer_idx]
    low_outputs = mean_pool(low_outputs, input_mask)
    if FLAGS.use_cls_token:
      high_outputs = model.get_pooled_output()
    else:
      high_outputs = all_layer_outputs[FLAGS.high_layer_idx]
      high_outputs = mean_pool(high_outputs, input_mask)
    outputs = (low_outputs, high_outputs)

  tvars = tf.trainable_variables()
  (assignment_map,
   initialized_variable_names) = modeling.get_assignment_map_from_checkpoint(
       tvars, init_checkpoint)

  tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
  return input_ids, input_mask, input_type_ids, outputs
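

# mean_pool is called above and below but is not part of this listing; a
# minimal sketch, assuming it averages the token vectors over the unmasked
# positions (behaviour inferred from its call sites):
def mean_pool(sequence_output, input_mask):
  # sequence_output: [batch, seq_len, hidden]; input_mask: [batch, seq_len]
  mask = tf.cast(tf.expand_dims(input_mask, axis=-1), tf.float32)
  summed = tf.reduce_sum(sequence_output * mask, axis=1)
  n_tokens = tf.maximum(tf.reduce_sum(mask, axis=1), 1e-9)
  return summed / n_tokens
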
def encode(embedding_output, input_ids, input_mask, token_type_ids, config):
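    """Re-encodes precomputed token embeddings with the pretrained ALBERT stack.

    `embedding_output` holds the (relaxed) word embeddings; token type and
    position embeddings are added here, then the transformer layers are run
    with reuse=True so the variables created by AlbertModel are shared.
    """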
    with tf.variable_scope("bert", reuse=True):
        with tf.variable_scope("embeddings", reuse=True):
            embedding_output = modeling.embedding_postprocessor(
                input_tensor=embedding_output,
                use_token_type=True,
                token_type_ids=token_type_ids,
                token_type_vocab_size=config.type_vocab_size,
                token_type_embedding_name="token_type_embeddings",
                use_position_embeddings=True,
                position_embedding_name="position_embeddings",
                initializer_range=config.initializer_range,
                max_position_embeddings=config.max_position_embeddings,
                dropout_prob=config.hidden_dropout_prob)

        with tf.variable_scope("encoder", reuse=True):
            attention_mask = modeling.create_attention_mask_from_input_mask(
                input_ids, input_mask)

            all_encoder_layers, _ = modeling.transformer_model(
                input_tensor=embedding_output,
                attention_mask=attention_mask,
                hidden_size=config.hidden_size,
                num_hidden_layers=config.num_hidden_layers,
                num_attention_heads=config.num_attention_heads,
                intermediate_size=config.intermediate_size,
                intermediate_act_fn=modeling.get_activation(config.hidden_act),
                hidden_dropout_prob=config.hidden_dropout_prob,
                attention_probs_dropout_prob=config.attention_probs_dropout_prob,
                initializer_range=config.initializer_range,
                do_return_all_layers=True)

        all_encoder_layers = [embedding_output] + all_encoder_layers
        if FLAGS.use_cls_token:
            with tf.variable_scope("pooler", reuse=True):
                first_token_tensor = tf.squeeze(
                    all_encoder_layers[-1][:, 0:1, :], 1)
                pooled_output = tf.layers.dense(
                    first_token_tensor,
                    config.hidden_size,
                    activation=tf.tanh,
                    kernel_initializer=modeling.create_initializer(
                        config.initializer_range))
        else:
            sequence_output = all_encoder_layers[FLAGS.low_layer_idx]
            pooled_output = mean_pool(sequence_output, input_mask)
    return pooled_output
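

# get_similarity_metric is defined elsewhere in the repo; a hypothetical sketch
# of a cosine variant, only to illustrate the interface assumed below (inputs
# of shape [batch, dim]; with rtn_loss=True it returns a per-example loss):
def _cosine_loss_sketch(encoded, targets):
  encoded = tf.nn.l2_normalize(encoded, axis=-1)
  targets = tf.nn.l2_normalize(targets, axis=-1)
  # 1 - cosine similarity: 0 when the embeddings are perfectly aligned.
  return 1.0 - tf.reduce_sum(encoded * targets, axis=-1)
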
def optimization_inversion():
  tokenizer = tokenization.FullTokenizer(
      vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case,
      spm_model_file=FLAGS.spm_model_file)
  cls_id = tokenizer.vocab['[CLS]']
  sep_id = tokenizer.vocab['[SEP]']
  mask_id = tokenizer.vocab['[MASK]']

  _, _, x, y = load_inversion_data()
  y = y[0]
  filters = [cls_id, sep_id, mask_id]
  y = filter_labels(y, filters)

  batch_size = FLAGS.batch_size
  seq_len = FLAGS.seq_len
  max_iters = FLAGS.max_iters

  albert_config = modeling.AlbertConfig.from_json_file(FLAGS.albert_config_file)
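  # Dummy fixed-shape inputs: seq_len relaxed token positions plus the
  # [CLS]/[SEP] tokens that are prepended/appended further below.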
  input_ids = tf.ones((batch_size, seq_len + 2), tf.int32)
  input_mask = tf.ones_like(input_ids, tf.int32)
  input_type_ids = tf.zeros_like(input_ids, tf.int32)

  model = modeling.AlbertModel(
      config=albert_config,
      is_training=False,
      input_ids=input_ids,
      input_mask=input_mask,
      token_type_ids=input_type_ids,
      use_one_hot_embeddings=False)

  word_emb = model.output_embedding_table

  albert_vars = tf.trainable_variables()
  (assignment_map,
   _) = modeling.get_assignment_map_from_checkpoint(albert_vars,
                                                    FLAGS.init_checkpoint)
  tf.train.init_from_checkpoint(FLAGS.init_checkpoint, assignment_map)

  batch_cls_ids = tf.ones((batch_size, 1), tf.int32) * cls_id
  batch_sep_ids = tf.ones((batch_size, 1), tf.int32) * sep_id
  cls_emb = tf.nn.embedding_lookup(word_emb, batch_cls_ids)
  sep_emb = tf.nn.embedding_lookup(word_emb, batch_sep_ids)

  prob_mask = np.zeros((albert_config.vocab_size,), np.float32)
  prob_mask[filters] = -1e9
  prob_mask = tf.constant(prob_mask, dtype=tf.float32)

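  # Continuous relaxation of the unknown input tokens: one trainable logit
  # vector per position, turned below into a vocabulary distribution with a
  # temperature-scaled softmax ([CLS]/[SEP]/[MASK] are masked out by prob_mask).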
  logit_inputs = tf.get_variable(
      name='inputs',
      shape=(batch_size, seq_len, albert_config.vocab_size),
      initializer=tf.random_uniform_initializer(-0.1, 0.1))

  t_vars = [logit_inputs]
  t_var_names = {logit_inputs.name}

  logit_inputs += prob_mask
  prob_inputs = tf.nn.softmax(logit_inputs / FLAGS.temp, axis=-1)

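  # Expected word embeddings under the relaxed distribution, wrapped with the
  # fixed [CLS] and [SEP] embeddings to match the model's input format.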
  emb_inputs = tf.matmul(prob_inputs, word_emb)
  emb_inputs = tf.concat([cls_emb, emb_inputs, sep_emb], axis=1)

  if FLAGS.low_layer_idx == 0:
    encoded = mean_pool(emb_inputs, input_mask)
  else:
    encoded = encode(emb_inputs, input_ids, input_mask, input_type_ids,
                     albert_config)
  targets = tf.placeholder(
        tf.float32, shape=(batch_size, encoded.shape.as_list()[-1]))

  loss = get_similarity_metric(encoded, targets, FLAGS.metric, rtn_loss=True)
  loss = tf.reduce_sum(loss)

  optimizer = tf.train.AdamOptimizer(FLAGS.lr)

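  # Record the variables that exist before apply_gradients so the optimizer's
  # slot variables (created next), together with the relaxed input logits, can
  # be re-initialized for every new batch.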
  start_vars = set(v.name for v in tf.global_variables()
                   if v.name not in t_var_names)
  grads_and_vars = optimizer.compute_gradients(loss, t_vars)
  train_ops = optimizer.apply_gradients(
      grads_and_vars, global_step=tf.train.get_or_create_global_step())

  end_vars = tf.global_variables()
  new_vars = [v for v in end_vars if v.name not in start_vars]

  preds = tf.argmax(prob_inputs, axis=-1)
  batch_init_ops = tf.variables_initializer(new_vars)

  total_it = len(x) // batch_size
  with tf.Session() as sess:
    sess.run([tf.global_variables_initializer(), tf.tables_initializer()])

    def invert_one_batch(batch_targets):
      sess.run(batch_init_ops)
      feed_dict = {targets: batch_targets}
      prev = 1e6
      for i in range(max_iters):
        curr, _ = sess.run([loss, train_ops], feed_dict)
        # stop if no progress
        if (i + 1) % (max_iters // 10) == 0 and curr > prev:
          break
        prev = curr
      return sess.run([preds, loss], feed_dict)

    start_time = time.time()
    it = 0.0
    all_tp, all_fp, all_fn, all_err = 0.0, 0.0, 0.0, 0.0

    for batch_idx in iterate_minibatches_indices(len(x), batch_size,
                                                 False, False):
      y_pred, err = invert_one_batch(x[batch_idx])
      tp, fp, fn = tp_fp_fn_metrics_np(y_pred, y[batch_idx])
      # for yp, yt in zip(y_pred, y[batch_idx]):
      #   print(' '.join(set(tokenizer.convert_ids_to_tokens(yp))))
      #   print(' '.join(set(tokenizer.convert_ids_to_tokens(yt))))

      it += 1.0
      all_err += err
      all_tp += tp
      all_fp += fp
      all_fn += fn

      all_pre = all_tp / (all_tp + all_fp + 1e-7)
      all_rec = all_tp / (all_tp + all_fn + 1e-7)
      all_f1 = 2 * all_pre * all_rec / (all_pre + all_rec + 1e-7)

      if it % FLAGS.print_every == 0:
        it_time = (time.time() - start_time) / it
        log("Iter {:.2f}%, err={}, pre={:.2f}%, rec={:.2f}%, f1={:.2f}%,"
            " {:.2f} sec/it".format(it / total_it * 100, all_err / it,
                                    all_pre * 100, all_rec * 100,
                                    all_f1 * 100, it_time))

    all_pre = all_tp / (all_tp + all_fp + 1e-7)
    all_rec = all_tp / (all_tp + all_fn + 1e-7)
    all_f1 = 2 * all_pre * all_rec / (all_pre + all_rec + 1e-7)
    log("Final err={}, pre={:.2f}%, rec={:.2f}%, f1={:.2f}%".format(
      all_err / it, all_pre * 100, all_rec * 100, all_f1 * 100))
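

# Hypothetical entry point; a minimal sketch assuming the script is driven by
# tf.app flags, as in the ALBERT codebase:
def main(_):
  optimization_inversion()


if __name__ == '__main__':
  tf.app.run(main)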