type_one_hot_labels = tf.one_hot(type_label, depth=13, dtype=tf.float32) type_per_example_loss = -tf.reduce_sum(type_one_hot_labels * type_log_probs, axis=-1) type_cls_loss = tf.reduce_mean(type_per_example_loss) # Connective string Classification Loss conn_probabilities = tf.nn.softmax(conn_logits, axis=-1) conn_log_probs = tf.nn.log_softmax(conn_logits, axis=-1) conn_one_hot_labels = tf.one_hot(conn_label, depth=71, dtype=tf.float32) conn_per_example_loss = -tf.reduce_sum(conn_one_hot_labels * conn_log_probs, axis=-1) conn_cls_loss = tf.reduce_mean(conn_per_example_loss) # Generation Loss mle_loss = transformer_utils.smoothing_cross_entropy(outputs.logits, labels, vocab_size, loss_label_confidence) mle_loss = tf.reduce_sum(mle_loss * is_target) / tf.reduce_sum(is_target) type_const_cls = 0.2 conn_const_cls = 0.2 total_loss = mle_loss + (type_const_cls * type_cls_loss) + (conn_const_cls * conn_cls_loss) tvars = tf.trainable_variables() non_bert_vars = [var for var in tvars if 'bert' not in var.name] bert_vars = [var for var in tvars if 'bert' in var.name] enc_vars = [var for var in bert_vars if "/encoder/" in var.name]
def main():
    """Entrypoint: build the Transformer graph, then train/evaluate or test.

    Behaviour is selected by ``FLAGS.run_mode``:

    * ``'train_and_evaluate'`` -- trains for ``config_data.max_train_epoch``
      epochs, evaluating token-id BLEU on the dev set every
      ``config_data.eval_steps`` steps and saving the best checkpoint to
      ``<FLAGS.model_dir>/best-model.ckpt``.
    * ``'test'`` -- restores the latest checkpoint found in
      ``FLAGS.model_dir`` and writes decoded test-set text to
      ``<FLAGS.model_dir>/test.output``.

    Any other value builds the graph and initialises variables but runs
    neither training nor testing.

    Raises:
        ValueError: in ``'test'`` mode when ``FLAGS.model_dir`` contains no
            checkpoint.
    """
    # Load data
    print('Loading data ...')
    train_data, dev_data, test_data = data_utils.load_data_numpy(
        config_data.input_dir, config_data.filename_prefix)
    print('Load data done')

    # NOTE(review): pickle.load can execute arbitrary code; only point
    # `config_data.vocab_file` at trusted, locally produced files.
    with open(config_data.vocab_file, 'rb') as f:
        id2w = pickle.load(f)
    vocab_size = len(id2w)
    print('vocab_size {}'.format(vocab_size))
    # NOTE(review): BOS/EOS ids are hard-coded; presumably they match the
    # preprocessing that built `vocab_file` -- confirm.
    bos_token_id, eos_token_id = 1, 2

    beam_width = config_model.beam_width

    # Create logging
    tx.utils.maybe_create_dir(FLAGS.model_dir)
    logging_file = os.path.join(FLAGS.model_dir, 'logging.txt')
    logger = utils.get_logger(logging_file)
    # print() does not %-format extra positional args; interpolate explicitly.
    print('logging file is saved in: %s' % logging_file)

    # Build model graph
    encoder_input = tf.placeholder(tf.int64, shape=(None, None))
    decoder_input = tf.placeholder(tf.int64, shape=(None, None))
    # (text sequence length excluding padding; id 0 is padding)
    encoder_input_length = tf.reduce_sum(
        1 - tf.to_int32(tf.equal(encoder_input, 0)), axis=1)
    decoder_input_length = tf.reduce_sum(
        1 - tf.to_int32(tf.equal(decoder_input, 0)), axis=1)

    labels = tf.placeholder(tf.int64, shape=(None, None))
    # Mask of non-padding target positions, used to average the loss.
    is_target = tf.to_float(tf.not_equal(labels, 0))

    global_step = tf.Variable(0, dtype=tf.int64, trainable=False)
    learning_rate = tf.placeholder(tf.float64, shape=(), name='lr')

    embedder = tx.modules.WordEmbedder(
        vocab_size=vocab_size, hparams=config_model.emb)
    encoder = TransformerEncoder(hparams=config_model.encoder)

    encoder_output = encoder(inputs=embedder(encoder_input),
                             sequence_length=encoder_input_length)

    # The decoder ties the input word embedding with the output logit layer.
    # As the decoder masks out <PAD>'s embedding, which in effect means
    # <PAD> has all-zero embedding, so here we explicitly set <PAD>'s embedding
    # to all-zero.
    tgt_embedding = tf.concat(
        [tf.zeros(shape=[1, embedder.dim]), embedder.embedding[1:, :]],
        axis=0)
    decoder = TransformerDecoder(embedding=tgt_embedding,
                                 hparams=config_model.decoder)

    # For training
    outputs = decoder(
        memory=encoder_output,
        memory_sequence_length=encoder_input_length,
        inputs=embedder(decoder_input),
        sequence_length=decoder_input_length,
        decoding_strategy='train_greedy',
        mode=tf.estimator.ModeKeys.TRAIN)

    mle_loss = transformer_utils.smoothing_cross_entropy(
        outputs.logits, labels, vocab_size,
        config_model.loss_label_confidence)
    mle_loss = tf.reduce_sum(mle_loss * is_target) / tf.reduce_sum(is_target)

    train_op = tx.core.get_train_op(
        mle_loss,
        learning_rate=learning_rate,
        global_step=global_step,
        hparams=config_model.opt)

    tf.summary.scalar('lr', learning_rate)
    tf.summary.scalar('mle_loss', mle_loss)
    summary_merged = tf.summary.merge_all()

    # For inference
    start_tokens = tf.fill([tx.utils.get_batch_size(encoder_input)],
                           bos_token_id)
    predictions = decoder(
        memory=encoder_output,
        memory_sequence_length=encoder_input_length,
        decoding_strategy='infer_greedy',
        beam_width=beam_width,
        alpha=config_model.alpha,
        start_tokens=start_tokens,
        end_token=eos_token_id,
        max_decoding_length=config_data.max_decoding_length,
        mode=tf.estimator.ModeKeys.PREDICT)
    if beam_width <= 1:
        inferred_ids = predictions[0].sample_id
    else:
        # Uses the best sample by beam search
        inferred_ids = predictions['sample_id'][:, :, 0]

    saver = tf.train.Saver(max_to_keep=5)
    best_results = {'score': 0, 'epoch': -1}

    def _eval_epoch(sess, epoch, mode):
        """Decode dev (``'eval'``) or test (``'test'``) data and score/write it.

        In ``'eval'`` mode, computes token-id BLEU and saves a checkpoint
        when it beats ``best_results['score']``. In ``'test'`` mode, maps
        ids back to words and writes hypothesis/reference files.

        Raises:
            ValueError: if ``mode`` is neither ``'eval'`` nor ``'test'``.
        """
        if mode == 'eval':
            eval_data = dev_data
        elif mode == 'test':
            eval_data = test_data
        else:
            raise ValueError('`mode` should be either "eval" or "test".')

        references, hypotheses = [], []
        bsize = config_data.test_batch_size
        for i in range(0, len(eval_data), bsize):
            sources, targets = zip(*eval_data[i:i + bsize])
            x_block = data_utils.source_pad_concat_convert(sources)
            feed_dict = {
                encoder_input: x_block,
                tx.global_mode(): tf.estimator.ModeKeys.EVAL,
            }
            fetches = {
                'inferred_ids': inferred_ids,
            }
            fetches_ = sess.run(fetches, feed_dict=feed_dict)

            hypotheses.extend(h.tolist() for h in fetches_['inferred_ids'])
            references.extend(r.tolist() for r in targets)

        # Strip EOS once, after all batches are decoded, rather than
        # re-stripping the whole accumulated lists after every batch.
        hypotheses = utils.list_strip_eos(hypotheses, eos_token_id)
        references = utils.list_strip_eos(references, eos_token_id)

        if mode == 'eval':
            # Writes results to files to evaluate BLEU
            # For 'eval' mode, the BLEU is based on token ids (rather than
            # text tokens) and serves only as a surrogate metric to monitor
            # the training process
            fname = os.path.join(FLAGS.model_dir, 'tmp.eval')
            hypotheses = tx.utils.str_join(hypotheses)
            references = tx.utils.str_join(references)
            hyp_fn, ref_fn = tx.utils.write_paired_text(
                hypotheses, references, fname, mode='s')
            eval_bleu = bleu_wrapper(ref_fn, hyp_fn, case_sensitive=True)
            eval_bleu = 100. * eval_bleu
            logger.info('epoch: %d, eval_bleu %.4f', epoch, eval_bleu)
            print('epoch: %d, eval_bleu %.4f' % (epoch, eval_bleu))

            if eval_bleu > best_results['score']:
                logger.info('epoch: %d, best bleu: %.4f', epoch, eval_bleu)
                best_results['score'] = eval_bleu
                best_results['epoch'] = epoch
                model_path = os.path.join(FLAGS.model_dir, 'best-model.ckpt')
                logger.info('saving model to %s', model_path)
                print('saving model to %s' % model_path)
                saver.save(sess, model_path)

        elif mode == 'test':
            # For 'test' mode, together with the cmds in README.md, BLEU
            # is evaluated based on text tokens, which is the standard metric.
            fname = os.path.join(FLAGS.model_dir, 'test.output')
            hwords, rwords = [], []
            for hyp, ref in zip(hypotheses, references):
                hwords.append([id2w[y] for y in hyp])
                rwords.append([id2w[y] for y in ref])
            hwords = tx.utils.str_join(hwords)
            rwords = tx.utils.str_join(rwords)
            hyp_fn, ref_fn = tx.utils.write_paired_text(
                hwords, rwords, fname, mode='s')
            logger.info('Test output written to file: %s', hyp_fn)
            print('Test output written to file: %s' % hyp_fn)

    def _train_epoch(sess, epoch, step, smry_writer):
        """Run one pass over the (shuffled) training data.

        Logs loss and summaries every ``config_data.display_steps`` steps and
        runs dev evaluation every ``config_data.eval_steps`` steps.

        Returns:
            The global step after the last training batch of this epoch.
        """
        random.shuffle(train_data)
        # Bucket examples of similar (source, target) lengths together.
        train_iter = data.iterator.pool(
            train_data,
            config_data.batch_size,
            key=lambda x: (len(x[0]), len(x[1])),
            batch_size_fn=utils.batch_size_fn,
            random_shuffler=data.iterator.RandomShuffler())

        for train_batch in train_iter:
            if len(train_batch) == 0:
                continue
            in_arrays = data_utils.seq2seq_pad_concat_convert(train_batch)
            feed_dict = {
                encoder_input: in_arrays[0],
                decoder_input: in_arrays[1],
                labels: in_arrays[2],
                learning_rate: utils.get_lr(step, config_model.lr)
            }
            fetches = {
                'step': global_step,
                'train_op': train_op,
                'smry': summary_merged,
                'loss': mle_loss,
            }

            fetches_ = sess.run(fetches, feed_dict=feed_dict)

            step, loss = fetches_['step'], fetches_['loss']
            if step and step % config_data.display_steps == 0:
                logger.info('step: %d, loss: %.4f', step, loss)
                print('step: %d, loss: %.4f' % (step, loss))
                smry_writer.add_summary(fetches_['smry'], global_step=step)

            if step and step % config_data.eval_steps == 0:
                _eval_epoch(sess, epoch, mode='eval')
        return step

    # Run the graph
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        sess.run(tf.tables_initializer())

        smry_writer = tf.summary.FileWriter(FLAGS.model_dir, graph=sess.graph)
        try:
            if FLAGS.run_mode == 'train_and_evaluate':
                step = 0
                for epoch in range(config_data.max_train_epoch):
                    step = _train_epoch(sess, epoch, step, smry_writer)

            elif FLAGS.run_mode == 'test':
                ckpt_path = tf.train.latest_checkpoint(FLAGS.model_dir)
                if ckpt_path is None:
                    raise ValueError(
                        'No checkpoint found in %s' % FLAGS.model_dir)
                saver.restore(sess, ckpt_path)
                _eval_epoch(sess, 0, mode='test')
        finally:
            # Flush pending summaries and release the event file.
            smry_writer.close()
def generator(text_ids, text_keyword_id, text_keyword_length, labels,
              text_length, temperature, vocab_size, batch_size, seq_len,
              gen_emb_dim, mem_slots, head_size, num_heads, hidden_dim,
              start_token):
    """Build a label-conditioned Transformer encoder-decoder generator.

    The keyword sequence is encoded with a Transformer encoder. The first
    position of the encoder output is then replaced by a projection of the
    label (``labels``) for teacher-forced training, and by a projection of
    ``1 - labels`` for Gumbel-softmax decoding.

    Args:
        text_ids: target token ids. The decoder reads the full sequence and
            the loss compares ``logits[:, :-1]`` with ``text_ids[:, 1:]``
            (next-token prediction); id 0 is treated as padding.
        text_keyword_id: source (keyword) token ids fed to the encoder.
        text_keyword_length: source lengths, used as the encoder
            ``sequence_length`` and the decoder ``memory_sequence_length``.
        labels: per-example labels; reshaped to ``[-1, 1]`` and cast to
            float. NOTE(review): ``1 - labels`` suggests binary {0, 1}
            labels -- confirm.
        text_length: target lengths, used for the target position embedding.
        temperature: Gumbel-softmax temperature.
        vocab_size: vocabulary size.
        batch_size: static batch size, used for the label-embedding reshape
            and the Gumbel start tokens.
        seq_len: ``position_size`` of the sinusoid position embedder and the
            length of the source position embedding.
        gen_emb_dim, mem_slots, head_size, num_heads, hidden_dim, start_token:
            not used by this function (presumably kept so the signature
            matches existing callers). The model size comes from
            ``trans_config.hidden_dim``, and the Gumbel start token is
            hard-coded to id 1.

    Returns:
        Tuple ``(gumbel_logits, gumbel_sample_ids, pretrain_loss, gen_o)``:
        the Gumbel-softmax decoder logits and sampled ids, the smoothing
        cross-entropy loss averaged over non-padding target tokens, and the
        sum over all decoded positions of the maximum Gumbel logit.
    """
    # Target positions 1..T-1 that are not padding contribute to the loss.
    is_target = tf.to_float(tf.not_equal(text_ids[:, 1:], 0))

    # Source word embedding, scaled by sqrt(hidden_dim).
    src_word_embedder = tx.modules.WordEmbedder(
        vocab_size=vocab_size, hparams=trans_config.emb)
    src_word_embeds = src_word_embedder(text_keyword_id)
    src_word_embeds = src_word_embeds * trans_config.hidden_dim**0.5

    # Position embedding (shared b/w source and target)
    pos_embedder = tx.modules.SinusoidsPositionEmbedder(
        position_size=seq_len,
        hparams=trans_config.position_embedder_hparams)
    src_pos_embeds = pos_embedder(sequence_length=seq_len)
    src_input_embedding = src_word_embeds + src_pos_embeds

    encoder = TransformerEncoder(hparams=trans_config.encoder)
    encoder_output = encoder(inputs=src_input_embedding,
                             sequence_length=text_keyword_length)

    # Label conditioning: overwrite encoder position 0 with a learned
    # projection of the label (training) or of the flipped label (sampling).
    label_connector = MLPTransformConnector(
        output_size=trans_config.hidden_dim)
    labels = tf.to_float(tf.reshape(labels, [-1, 1]))
    # Reshape to the connector's own output size instead of a hard-coded 512,
    # so the graph stays valid if `trans_config.hidden_dim` changes.
    label_shape = [batch_size, 1, trans_config.hidden_dim]
    c = tf.reshape(label_connector(labels), label_shape)
    c_ = tf.reshape(label_connector(1 - labels), label_shape)
    encoder_body = encoder_output[:, 1:, :]
    encoder_output = tf.concat([c, encoder_body], axis=1)
    encoder_output_ = tf.concat([c_, encoder_body], axis=1)

    # The decoder ties the input word embedding with the output logit layer.
    # As the decoder masks out <PAD>'s embedding, which in effect means
    # <PAD> has all-zero embedding, so here we explicitly set <PAD>'s embedding
    # to all-zero.
    tgt_embedding = tf.concat([
        tf.zeros(shape=[1, src_word_embedder.dim]),
        src_word_embedder.embedding[1:, :]
    ], axis=0)
    tgt_embedder = tx.modules.WordEmbedder(tgt_embedding)
    tgt_word_embeds = tgt_embedder(text_ids)
    tgt_word_embeds = tgt_word_embeds * trans_config.hidden_dim**0.5

    tgt_pos_embeds = pos_embedder(sequence_length=text_length)
    tgt_input_embedding = tgt_word_embeds + tgt_pos_embeds

    # Output projection: the transposed target embedding matrix (tied weights).
    _output_w = tf.transpose(tgt_embedder.embedding, (1, 0))

    decoder = TransformerDecoder(vocab_size=vocab_size,
                                 output_layer=_output_w,
                                 hparams=trans_config.decoder)
    # For training (teacher forcing, conditioned on the true label)
    outputs = decoder(memory=encoder_output,
                      memory_sequence_length=text_keyword_length,
                      inputs=tgt_input_embedding,
                      decoding_strategy='train_greedy',
                      mode=tf.estimator.ModeKeys.TRAIN)

    mle_loss = transformer_utils.smoothing_cross_entropy(
        outputs.logits[:, :-1, :], text_ids[:, 1:], vocab_size,
        trans_config.loss_label_confidence)
    pretrain_loss = tf.reduce_sum(
        mle_loss * is_target) / tf.reduce_sum(is_target)

    # Gumbel-softmax decoding, used in training, conditioned on the flipped
    # label. Start/end ids (1 and 2) are hard-coded; `start_token` is unused.
    start_tokens = np.ones(batch_size, int)
    end_token = 2
    gumbel_helper = GumbelSoftmaxEmbeddingHelper(
        tgt_embedding, start_tokens, end_token, temperature)

    gumbel_outputs, _ = decoder(
        memory=encoder_output_,
        memory_sequence_length=text_keyword_length,
        helper=gumbel_helper)

    gen_o = tf.reduce_sum(tf.reduce_max(gumbel_outputs.logits, axis=2))

    return gumbel_outputs.logits, gumbel_outputs.sample_id, pretrain_loss, gen_o