def model_fn_builder(albert_config, init_checkpoint, use_one_hot_embeddings):
  """Builds the ALBERT graph and returns input placeholders and sentence outputs."""
  input_ids = tf.placeholder(
      tf.int32, [None, FLAGS.max_seq_length], name='input_ids')
  input_mask = tf.placeholder(
      tf.int32, [None, FLAGS.max_seq_length], name='input_mask')
  input_type_ids = tf.placeholder(
      tf.int32, [None, FLAGS.max_seq_length], name='segment_ids')

  model = modeling.AlbertModel(
      config=albert_config,
      is_training=False,
      input_ids=input_ids,
      input_mask=input_mask,
      token_type_ids=input_type_ids,
      use_one_hot_embeddings=use_one_hot_embeddings)

  # Index 0 is the word embedding output; indices 1..N are the encoder layers.
  all_layer_outputs = [model.get_word_embedding_output()]
  all_layer_outputs += model.get_all_encoder_layers()

  if FLAGS.high_layer_idx == FLAGS.low_layer_idx:
    # Single-layer sentence representation.
    if FLAGS.use_cls_token:
      outputs = model.get_pooled_output()
    else:
      outputs = all_layer_outputs[FLAGS.high_layer_idx]
      outputs = mean_pool(outputs, input_mask)
  else:
    # Two sentence representations: one from a low layer, one from a high layer.
    low_outputs = all_layer_outputs[FLAGS.low_layer_idx]
    low_outputs = mean_pool(low_outputs, input_mask)
    if FLAGS.use_cls_token:
      high_outputs = model.get_pooled_output()
    else:
      high_outputs = all_layer_outputs[FLAGS.high_layer_idx]
      high_outputs = mean_pool(high_outputs, input_mask)
    outputs = (low_outputs, high_outputs)

  # Restore the pretrained ALBERT weights.
  tvars = tf.trainable_variables()
  (assignment_map,
   _) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint)
  tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
  return input_ids, input_mask, input_type_ids, outputs
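
# Minimal usage sketch (illustrative only, not part of the original pipeline):
# the placeholders returned above can be fed with tokenized batches to compute
# the target sentence embeddings that the inversion attack later tries to
# recover. `albert_config`, `batch`, and the flag values are assumed to be set
# up by the surrounding code.
#
#   input_ids, input_mask, input_type_ids, outputs = model_fn_builder(
#       albert_config, FLAGS.init_checkpoint, use_one_hot_embeddings=False)
#   with tf.Session() as sess:
#     sess.run(tf.global_variables_initializer())
#     target_embs = sess.run(outputs, feed_dict={
#         input_ids: batch['input_ids'],
#         input_mask: batch['input_mask'],
#         input_type_ids: batch['segment_ids']})
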
def encode(embedding_output, input_mask, token_type_ids, config):
  """Re-encodes (soft) input embeddings with the pretrained ALBERT encoder."""
  with tf.variable_scope("bert", reuse=True):
    with tf.variable_scope("embeddings", reuse=True):
      # Add token type and position embeddings to the word embeddings.
      embedding_output = modeling.embedding_postprocessor(
          input_tensor=embedding_output,
          use_token_type=True,
          token_type_ids=token_type_ids,
          token_type_vocab_size=config.type_vocab_size,
          token_type_embedding_name="token_type_embeddings",
          use_position_embeddings=True,
          position_embedding_name="position_embeddings",
          initializer_range=config.initializer_range,
          max_position_embeddings=config.max_position_embeddings,
          dropout_prob=config.hidden_dropout_prob)

    with tf.variable_scope("encoder", reuse=True):
      attention_mask = modeling.create_attention_mask_from_input_mask(
          embedding_output, input_mask)
      all_encoder_layers, _ = modeling.transformer_model(
          input_tensor=embedding_output,
          attention_mask=attention_mask,
          hidden_size=config.hidden_size,
          num_hidden_layers=config.num_hidden_layers,
          num_attention_heads=config.num_attention_heads,
          intermediate_size=config.intermediate_size,
          intermediate_act_fn=modeling.get_activation(config.hidden_act),
          hidden_dropout_prob=config.hidden_dropout_prob,
          attention_probs_dropout_prob=config.attention_probs_dropout_prob,
          initializer_range=config.initializer_range,
          do_return_all_layers=True)

    # Index 0 is the embedding output; indices 1..N are the encoder layers.
    all_encoder_layers = [embedding_output] + all_encoder_layers

    if FLAGS.use_cls_token:
      with tf.variable_scope("pooler", reuse=True):
        first_token_tensor = tf.squeeze(all_encoder_layers[-1][:, 0:1, :], 1)
        pooled_output = tf.layers.dense(
            first_token_tensor,
            config.hidden_size,
            activation=tf.tanh,
            kernel_initializer=modeling.create_initializer(
                config.initializer_range))
    else:
      sequence_output = all_encoder_layers[FLAGS.low_layer_idx]
      pooled_output = mean_pool(sequence_output, input_mask)
  return pooled_output
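
# `mean_pool` is defined elsewhere in this codebase; it is assumed to average
# the token vectors over the unmasked positions, roughly along these lines
# (hypothetical sketch, shown only to make the pooling step concrete):
#
#   def mean_pool(sequence, mask):
#     # sequence: [batch, seq_len, hidden], mask: [batch, seq_len] of 0/1
#     mask = tf.cast(tf.expand_dims(mask, axis=-1), tf.float32)
#     return tf.reduce_sum(sequence * mask, axis=1) / (
#         tf.reduce_sum(mask, axis=1) + 1e-9)
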
def optimization_inversion():
  """Inverts target sentence embeddings by optimizing relaxed token inputs."""
  tokenizer = tokenization.FullTokenizer(
      vocab_file=FLAGS.vocab_file,
      do_lower_case=FLAGS.do_lower_case,
      spm_model_file=FLAGS.spm_model_file)
  cls_id = tokenizer.vocab['[CLS]']
  sep_id = tokenizer.vocab['[SEP]']
  mask_id = tokenizer.vocab['[MASK]']

  _, _, x, y = load_inversion_data()
  y = y[0]
  # Special tokens are excluded from the labels and from the search space.
  filters = [cls_id, sep_id, mask_id]
  y = filter_labels(y, filters)

  batch_size = FLAGS.batch_size
  seq_len = FLAGS.seq_len
  max_iters = FLAGS.max_iters

  albert_config = modeling.AlbertConfig.from_json_file(
      FLAGS.albert_config_file)

  # Dummy inputs, only used to build the graph and load pretrained weights.
  input_ids = tf.ones((batch_size, seq_len + 2), tf.int32)
  input_mask = tf.ones_like(input_ids, tf.int32)
  input_type_ids = tf.zeros_like(input_ids, tf.int32)

  model = modeling.AlbertModel(
      config=albert_config,
      is_training=False,
      input_ids=input_ids,
      input_mask=input_mask,
      token_type_ids=input_type_ids,
      use_one_hot_embeddings=False)
  word_emb = model.output_embedding_table

  albert_vars = tf.trainable_variables()
  (assignment_map, _) = modeling.get_assignment_map_from_checkpoint(
      albert_vars, FLAGS.init_checkpoint)
  tf.train.init_from_checkpoint(FLAGS.init_checkpoint, assignment_map)

  # Fixed [CLS] and [SEP] embeddings are prepended and appended to each input.
  batch_cls_ids = tf.ones((batch_size, 1), tf.int32) * cls_id
  batch_sep_ids = tf.ones((batch_size, 1), tf.int32) * sep_id
  cls_emb = tf.nn.embedding_lookup(word_emb, batch_cls_ids)
  sep_emb = tf.nn.embedding_lookup(word_emb, batch_sep_ids)

  # Mask out filtered tokens so the softmax assigns them near-zero probability.
  prob_mask = np.zeros((albert_config.vocab_size,), np.float32)
  prob_mask[filters] = -1e9
  prob_mask = tf.constant(prob_mask, dtype=tf.float32)

  # The only trainable variable: unnormalized token logits for each position.
  logit_inputs = tf.get_variable(
      name='inputs',
      shape=(batch_size, seq_len, albert_config.vocab_size),
      initializer=tf.random_uniform_initializer(-0.1, 0.1))

  t_vars = [logit_inputs]
  t_var_names = {logit_inputs.name}

  # Continuous relaxation: soft token probabilities -> soft word embeddings.
  logit_inputs += prob_mask
  prob_inputs = tf.nn.softmax(logit_inputs / FLAGS.temp, axis=-1)
  emb_inputs = tf.matmul(prob_inputs, word_emb)
  emb_inputs = tf.concat([cls_emb, emb_inputs, sep_emb], axis=1)

  if FLAGS.low_layer_idx == 0:
    encoded = mean_pool(emb_inputs, input_mask)
  else:
    encoded = encode(emb_inputs, input_mask, input_type_ids, albert_config)

  targets = tf.placeholder(
      tf.float32, shape=(batch_size, encoded.shape.as_list()[-1]))

  loss = get_similarity_metric(encoded, targets, FLAGS.metric, rtn_loss=True)
  loss = tf.reduce_sum(loss)
  optimizer = tf.train.AdamOptimizer(FLAGS.lr)

  start_vars = set(v.name for v in tf.global_variables()
                   if v.name not in t_var_names)
  grads_and_vars = optimizer.compute_gradients(loss, t_vars)
  train_ops = optimizer.apply_gradients(
      grads_and_vars, global_step=tf.train.get_or_create_global_step())

  end_vars = tf.global_variables()
  new_vars = [v for v in end_vars if v.name not in start_vars]

  preds = tf.argmax(prob_inputs, axis=-1)
  # Re-initializes the logits and the optimizer slots for every new batch.
  batch_init_ops = tf.variables_initializer(new_vars)

  total_it = len(x) // batch_size

  with tf.Session() as sess:
    sess.run([tf.global_variables_initializer(), tf.tables_initializer()])

    def invert_one_batch(batch_targets):
      sess.run(batch_init_ops)
      feed_dict = {targets: batch_targets}
      prev = 1e6
      for i in range(max_iters):
        curr, _ = sess.run([loss, train_ops], feed_dict)
        # Stop early if the loss is no longer improving.
        if (i + 1) % (max_iters // 10) == 0 and curr > prev:
          break
        prev = curr
      return sess.run([preds, loss], feed_dict)

    start_time = time.time()
    it = 0.0
    all_tp, all_fp, all_fn, all_err = 0.0, 0.0, 0.0, 0.0

    for batch_idx in iterate_minibatches_indices(len(x), batch_size,
                                                 False, False):
      y_pred, err = invert_one_batch(x[batch_idx])
      tp, fp, fn = tp_fp_fn_metrics_np(y_pred, y[batch_idx])
      # for yp, yt in zip(y_pred, y[batch_idx]):
      #   print(' '.join(set(tokenizer.convert_ids_to_tokens(yp))))
      #   print(' '.join(set(tokenizer.convert_ids_to_tokens(yt))))
      it += 1.0
      all_err += err
      all_tp += tp
      all_fp += fp
      all_fn += fn

      all_pre = all_tp / (all_tp + all_fp + 1e-7)
      all_rec = all_tp / (all_tp + all_fn + 1e-7)
      all_f1 = 2 * all_pre * all_rec / (all_pre + all_rec + 1e-7)

      if it % FLAGS.print_every == 0:
        it_time = (time.time() - start_time) / it
        log("Iter {:.2f}%, err={}, pre={:.2f}%, rec={:.2f}%, f1={:.2f}%,"
            " {:.2f} sec/it".format(it / total_it * 100, all_err / it,
                                    all_pre * 100, all_rec * 100,
                                    all_f1 * 100, it_time))

  all_pre = all_tp / (all_tp + all_fp + 1e-7)
  all_rec = all_tp / (all_tp + all_fn + 1e-7)
  all_f1 = 2 * all_pre * all_rec / (all_pre + all_rec + 1e-7)
  log("Final err={}, pre={:.2f}%, rec={:.2f}%, f1={:.2f}%".format(
      all_err / it, all_pre * 100, all_rec * 100, all_f1 * 100))
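
# Illustrative entry point (hypothetical; the actual script may wire this up
# differently, e.g. selecting between several attack modes via a flag):
#
#   def main(_):
#     optimization_inversion()
#
#   if __name__ == '__main__':
#     tf.app.run()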