# Weights applied to the two auxiliary classification losses in total_loss.
type_const_cls = 0.2
conn_const_cls = 0.2

# --- Type classification loss (13 classes) ---------------------------------
# Cross-entropy written out explicitly as -sum(one_hot(label) * log_probs).
type_one_hot_labels = tf.one_hot(type_label, depth=13, dtype=tf.float32)
type_per_example_loss = -tf.reduce_sum(
    type_one_hot_labels * type_log_probs, axis=-1)
type_cls_loss = tf.reduce_mean(type_per_example_loss)

# --- Connective string classification loss (71 classes) --------------------
conn_probabilities = tf.nn.softmax(conn_logits, axis=-1)
conn_log_probs = tf.nn.log_softmax(conn_logits, axis=-1)
conn_one_hot_labels = tf.one_hot(conn_label, depth=71, dtype=tf.float32)
conn_per_example_loss = -tf.reduce_sum(
    conn_one_hot_labels * conn_log_probs, axis=-1)
conn_cls_loss = tf.reduce_mean(conn_per_example_loss)

# --- Generation loss --------------------------------------------------------
# Label-smoothed token cross-entropy, averaged with `is_target` as per-token
# weights (presumably a 0/1 mask of non-padding targets -- TODO confirm).
mle_loss = transformer_utils.smoothing_cross_entropy(
    outputs.logits, labels, vocab_size, loss_label_confidence)
mle_loss = tf.reduce_sum(mle_loss * is_target) / tf.reduce_sum(is_target)

# Generation loss plus the weighted auxiliary classification losses.
total_loss = (mle_loss
              + type_const_cls * type_cls_loss
              + conn_const_cls * conn_cls_loss)

# Partition trainable variables by whether their name contains 'bert'; the
# encoder subset is the 'bert' variables living under an "/encoder/" scope.
tvars = tf.trainable_variables()
bert_vars = [v for v in tvars if 'bert' in v.name]
non_bert_vars = [v for v in tvars if 'bert' not in v.name]
enc_vars = [v for v in bert_vars if "/encoder/" in v.name]
# Example #2 (0)
def main():
    """Entrypoint.

    Loads the train/dev/test splits and the pickled id-to-word vocabulary,
    builds the Transformer training and inference graphs, then either trains
    (``FLAGS.run_mode == 'train_and_evaluate'``) with periodic dev-set BLEU
    evaluation, or restores the latest checkpoint in ``FLAGS.model_dir`` and
    decodes the test split (``FLAGS.run_mode == 'test'``).
    """
    # Load data
    print('Loading data ...')
    train_data, dev_data, test_data = data_utils.load_data_numpy(
        config_data.input_dir, config_data.filename_prefix)
    print('Load data done')
    # NOTE(review): pickle.load can execute arbitrary code embedded in the
    # file; only point `config_data.vocab_file` at trusted files.
    with open(config_data.vocab_file, 'rb') as f:
        id2w = pickle.load(f)
    vocab_size = len(id2w)
    print('vocab_size {}'.format(vocab_size))
    # Hard-coded special-token ids; presumably these must match the ids used
    # when the data/vocab were preprocessed -- TODO confirm.
    bos_token_id, eos_token_id = 1, 2

    beam_width = config_model.beam_width

    # Create logging
    tx.utils.maybe_create_dir(FLAGS.model_dir)
    logging_file = os.path.join(FLAGS.model_dir, 'logging.txt')
    logger = utils.get_logger(logging_file)
    # %-format explicitly: print() does not interpolate '%s' placeholders.
    print('logging file is saved in: %s' % logging_file)

    # Build model graph
    encoder_input = tf.placeholder(tf.int64, shape=(None, None))
    decoder_input = tf.placeholder(tf.int64, shape=(None, None))
    # (text sequence length excluding padding; id 0 is treated as padding)
    encoder_input_length = tf.reduce_sum(
        1 - tf.to_int32(tf.equal(encoder_input, 0)), axis=1)
    decoder_input_length = tf.reduce_sum(
        1 - tf.to_int32(tf.equal(decoder_input, 0)), axis=1)

    labels = tf.placeholder(tf.int64, shape=(None, None))
    # Per-token weight: 1.0 for non-padding labels, 0.0 for padding (id 0).
    is_target = tf.to_float(tf.not_equal(labels, 0))

    global_step = tf.Variable(0, dtype=tf.int64, trainable=False)
    learning_rate = tf.placeholder(tf.float64, shape=(), name='lr')

    embedder = tx.modules.WordEmbedder(vocab_size=vocab_size,
                                       hparams=config_model.emb)
    encoder = TransformerEncoder(hparams=config_model.encoder)

    encoder_output = encoder(inputs=embedder(encoder_input),
                             sequence_length=encoder_input_length)

    # The decoder ties the input word embedding with the output logit layer.
    # As the decoder masks out <PAD>'s embedding, which in effect means
    # <PAD> has all-zero embedding, so here we explicitly set <PAD>'s embedding
    # to all-zero.
    tgt_embedding = tf.concat(
        [tf.zeros(shape=[1, embedder.dim]), embedder.embedding[1:, :]], axis=0)
    decoder = TransformerDecoder(embedding=tgt_embedding,
                                 hparams=config_model.decoder)
    # For training
    outputs = decoder(memory=encoder_output,
                      memory_sequence_length=encoder_input_length,
                      inputs=embedder(decoder_input),
                      sequence_length=decoder_input_length,
                      decoding_strategy='train_greedy',
                      mode=tf.estimator.ModeKeys.TRAIN)

    # Label-smoothed cross-entropy averaged over non-padding target tokens.
    mle_loss = transformer_utils.smoothing_cross_entropy(
        outputs.logits, labels, vocab_size, config_model.loss_label_confidence)
    mle_loss = tf.reduce_sum(mle_loss * is_target) / tf.reduce_sum(is_target)

    train_op = tx.core.get_train_op(mle_loss,
                                    learning_rate=learning_rate,
                                    global_step=global_step,
                                    hparams=config_model.opt)

    tf.summary.scalar('lr', learning_rate)
    tf.summary.scalar('mle_loss', mle_loss)
    summary_merged = tf.summary.merge_all()

    # For inference
    start_tokens = tf.fill([tx.utils.get_batch_size(encoder_input)],
                           bos_token_id)
    predictions = decoder(memory=encoder_output,
                          memory_sequence_length=encoder_input_length,
                          decoding_strategy='infer_greedy',
                          beam_width=beam_width,
                          alpha=config_model.alpha,
                          start_tokens=start_tokens,
                          end_token=eos_token_id,
                          max_decoding_length=config_data.max_decoding_length,
                          mode=tf.estimator.ModeKeys.PREDICT)
    if beam_width <= 1:
        inferred_ids = predictions[0].sample_id
    else:
        # Uses the best sample by beam search
        inferred_ids = predictions['sample_id'][:, :, 0]

    saver = tf.train.Saver(max_to_keep=5)
    best_results = {'score': 0, 'epoch': -1}

    def _eval_epoch(sess, epoch, mode):
        """Decode the dev ('eval') or test ('test') split.

        'eval': computes token-id BLEU, logs it, and saves
        ``best-model.ckpt`` whenever it beats ``best_results['score']``.
        'test': maps ids back to words via ``id2w`` and writes the paired
        hypothesis/reference files under ``FLAGS.model_dir``.

        Raises:
            ValueError: if ``mode`` is neither 'eval' nor 'test'.
        """
        if mode == 'eval':
            eval_data = dev_data
        elif mode == 'test':
            eval_data = test_data
        else:
            raise ValueError('`mode` should be either "eval" or "test".')

        references, hypotheses = [], []
        bsize = config_data.test_batch_size
        for i in range(0, len(eval_data), bsize):
            sources, targets = zip(*eval_data[i:i + bsize])
            x_block = data_utils.source_pad_concat_convert(sources)
            feed_dict = {
                encoder_input: x_block,
                tx.global_mode(): tf.estimator.ModeKeys.EVAL,
            }
            fetches = {
                'inferred_ids': inferred_ids,
            }
            fetches_ = sess.run(fetches, feed_dict=feed_dict)

            hypotheses.extend(h.tolist() for h in fetches_['inferred_ids'])
            references.extend(r.tolist() for r in targets)

        # Strip EOS once after all batches are decoded, instead of re-scanning
        # the whole accumulated lists after every batch.
        hypotheses = utils.list_strip_eos(hypotheses, eos_token_id)
        references = utils.list_strip_eos(references, eos_token_id)

        if mode == 'eval':
            # Writes results to files to evaluate BLEU
            # For 'eval' mode, the BLEU is based on token ids (rather than
            # text tokens) and serves only as a surrogate metric to monitor
            # the training process
            fname = os.path.join(FLAGS.model_dir, 'tmp.eval')
            hypotheses = tx.utils.str_join(hypotheses)
            references = tx.utils.str_join(references)
            hyp_fn, ref_fn = tx.utils.write_paired_text(hypotheses,
                                                        references,
                                                        fname,
                                                        mode='s')
            eval_bleu = bleu_wrapper(ref_fn, hyp_fn, case_sensitive=True)
            eval_bleu = 100. * eval_bleu
            logger.info('epoch: %d, eval_bleu %.4f', epoch, eval_bleu)
            print('epoch: %d, eval_bleu %.4f' % (epoch, eval_bleu))

            if eval_bleu > best_results['score']:
                logger.info('epoch: %d, best bleu: %.4f', epoch, eval_bleu)
                best_results['score'] = eval_bleu
                best_results['epoch'] = epoch
                model_path = os.path.join(FLAGS.model_dir, 'best-model.ckpt')
                logger.info('saving model to %s', model_path)
                print('saving model to %s' % model_path)
                saver.save(sess, model_path)

        elif mode == 'test':
            # For 'test' mode, together with the cmds in README.md, BLEU
            # is evaluated based on text tokens, which is the standard metric.
            fname = os.path.join(FLAGS.model_dir, 'test.output')
            hwords, rwords = [], []
            for hyp, ref in zip(hypotheses, references):
                hwords.append([id2w[y] for y in hyp])
                rwords.append([id2w[y] for y in ref])
            hwords = tx.utils.str_join(hwords)
            rwords = tx.utils.str_join(rwords)
            hyp_fn, ref_fn = tx.utils.write_paired_text(hwords,
                                                        rwords,
                                                        fname,
                                                        mode='s')
            logger.info('Test output writtn to file: %s', hyp_fn)
            print('Test output writtn to file: %s' % hyp_fn)

    def _train_epoch(sess, epoch, step, smry_writer):
        """Run one pass over the shuffled training data.

        Logs the loss and writes summaries every
        ``config_data.display_steps`` steps, and runs dev evaluation every
        ``config_data.eval_steps`` steps.

        Returns:
            The global step after the last batch (``step`` unchanged if every
            batch was empty).
        """
        random.shuffle(train_data)
        train_iter = data.iterator.pool(
            train_data,
            config_data.batch_size,
            key=lambda x: (len(x[0]), len(x[1])),
            batch_size_fn=utils.batch_size_fn,
            random_shuffler=data.iterator.RandomShuffler())

        for train_batch in train_iter:
            if len(train_batch) == 0:
                continue
            in_arrays = data_utils.seq2seq_pad_concat_convert(train_batch)
            feed_dict = {
                encoder_input: in_arrays[0],
                decoder_input: in_arrays[1],
                labels: in_arrays[2],
                learning_rate: utils.get_lr(step, config_model.lr)
            }
            fetches = {
                'step': global_step,
                'train_op': train_op,
                'smry': summary_merged,
                'loss': mle_loss,
            }

            fetches_ = sess.run(fetches, feed_dict=feed_dict)

            step, loss = fetches_['step'], fetches_['loss']
            if step and step % config_data.display_steps == 0:
                logger.info('step: %d, loss: %.4f', step, loss)
                print('step: %d, loss: %.4f' % (step, loss))
                smry_writer.add_summary(fetches_['smry'], global_step=step)

            if step and step % config_data.eval_steps == 0:
                _eval_epoch(sess, epoch, mode='eval')
        return step

    # Run the graph
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        sess.run(tf.tables_initializer())

        smry_writer = tf.summary.FileWriter(FLAGS.model_dir, graph=sess.graph)
        try:
            if FLAGS.run_mode == 'train_and_evaluate':
                step = 0
                for epoch in range(config_data.max_train_epoch):
                    step = _train_epoch(sess, epoch, step, smry_writer)

            elif FLAGS.run_mode == 'test':
                ckpt_path = tf.train.latest_checkpoint(FLAGS.model_dir)
                # latest_checkpoint returns None when nothing was saved;
                # fail with a clear message instead of inside restore().
                if ckpt_path is None:
                    raise ValueError(
                        'No checkpoint found in {}'.format(FLAGS.model_dir))
                saver.restore(sess, ckpt_path)
                _eval_epoch(sess, 0, mode='test')
        finally:
            # Flush buffered summaries/events to disk.
            smry_writer.close()
# Example #3 (0)
def generator(text_ids, text_keyword_id, text_keyword_length, labels,
              text_length, temperature, vocab_size, batch_size, seq_len,
              gen_emb_dim, mem_slots, head_size, num_heads, hidden_dim,
              start_token):
    """Build a label-conditioned Transformer generator graph.

    The keyword sequence is encoded with a Transformer encoder. Its first
    memory position is then replaced by a projection of the label. A
    teacher-forced decoding pass conditioned on ``labels`` gives the MLE
    pre-training loss. A Gumbel-softmax decoding pass is conditioned on the
    flipped label ``1 - labels``.

    Args:
      text_ids: target token ids. Position 0 is decoder input only; the loss
        is computed against ``text_ids[:, 1:]``. Id 0 is treated as padding.
      text_keyword_id: source (keyword) token ids fed to the encoder.
      text_keyword_length: source sequence lengths.
      labels: reshaped to ``[-1, 1]`` and cast to float, so presumably one
        scalar (0/1) label per example -- TODO confirm.
      text_length: target sequence length, used for position embeddings.
      temperature: Gumbel-softmax temperature.
      vocab_size: vocabulary size.
      batch_size: static batch size, used in reshapes and for start tokens.
      seq_len: position-embedding table size; also the sequence length used
        for the source position embeddings.
      gen_emb_dim, mem_slots, head_size, num_heads, hidden_dim, start_token:
        accepted for interface compatibility but unused here; model sizes
        come from ``trans_config``.

    Returns:
      Tuple ``(logits, sample_id, pretrain_loss, gen_o)``. ``logits`` and
      ``sample_id`` come from the Gumbel-softmax pass. ``pretrain_loss`` is
      the masked label-smoothed MLE loss. ``gen_o`` is the sum over all
      positions of the max Gumbel logit.
    """
    # Weight 1.0 for real target tokens, 0.0 for padding (id 0).
    is_target = tf.to_float(tf.not_equal(text_ids[:, 1:], 0))

    # Source word embedding, scaled by sqrt(hidden_dim).
    src_word_embedder = tx.modules.WordEmbedder(vocab_size=vocab_size,
                                                hparams=trans_config.emb)
    src_word_embeds = src_word_embedder(text_keyword_id)
    src_word_embeds = src_word_embeds * trans_config.hidden_dim**0.5

    # Position embedding (shared b/w source and target)
    pos_embedder = tx.modules.SinusoidsPositionEmbedder(
        position_size=seq_len, hparams=trans_config.position_embedder_hparams)
    # NOTE(review): source positions use the fixed `seq_len`, which assumes
    # keyword inputs are padded to `seq_len` -- confirm.
    src_pos_embeds = pos_embedder(sequence_length=seq_len)

    src_input_embedding = src_word_embeds + src_pos_embeds

    encoder = TransformerEncoder(hparams=trans_config.encoder)
    encoder_output = encoder(inputs=src_input_embedding,
                             sequence_length=text_keyword_length)

    # Modify sentiment label: overwrite memory position 0 with a projection
    # of the label (teacher-forced pass) or of the flipped label (Gumbel pass).
    label_connector = MLPTransformConnector(
        output_size=trans_config.hidden_dim)

    labels = tf.to_float(tf.reshape(labels, [-1, 1]))
    # The connector emits `trans_config.hidden_dim` features; reshape to that
    # width instead of a hard-coded 512 so other model sizes work.
    label_dim = trans_config.hidden_dim
    c = tf.reshape(label_connector(labels), [batch_size, 1, label_dim])
    c_ = tf.reshape(label_connector(1 - labels), [batch_size, 1, label_dim])
    keyword_memory = encoder_output[:, 1:, :]
    encoder_output = tf.concat([c, keyword_memory], axis=1)
    encoder_output_ = tf.concat([c_, keyword_memory], axis=1)

    # The decoder ties the input word embedding with the output logit layer.
    # As the decoder masks out <PAD>'s embedding, which in effect means
    # <PAD> has all-zero embedding, so here we explicitly set <PAD>'s embedding
    # to all-zero.
    tgt_embedding = tf.concat(
        [tf.zeros(shape=[1, src_word_embedder.dim]),
         src_word_embedder.embedding[1:, :]],
        axis=0)
    tgt_embedder = tx.modules.WordEmbedder(tgt_embedding)
    tgt_word_embeds = tgt_embedder(text_ids)
    tgt_word_embeds = tgt_word_embeds * trans_config.hidden_dim**0.5

    tgt_seq_len = text_length
    tgt_pos_embeds = pos_embedder(sequence_length=tgt_seq_len)

    tgt_input_embedding = tgt_word_embeds + tgt_pos_embeds

    # Output projection is the transposed target embedding matrix.
    _output_w = tf.transpose(tgt_embedder.embedding, (1, 0))

    decoder = TransformerDecoder(vocab_size=vocab_size,
                                 output_layer=_output_w,
                                 hparams=trans_config.decoder)
    # For training (teacher forcing)
    outputs = decoder(memory=encoder_output,
                      memory_sequence_length=text_keyword_length,
                      inputs=tgt_input_embedding,
                      decoding_strategy='train_greedy',
                      mode=tf.estimator.ModeKeys.TRAIN)

    # Logits at step t are scored against token t+1: drop the last logit and
    # the first target id.
    mle_loss = transformer_utils.smoothing_cross_entropy(
        outputs.logits[:, :-1, :], text_ids[:, 1:], vocab_size,
        trans_config.loss_label_confidence)
    pretrain_loss = tf.reduce_sum(
        mle_loss * is_target) / tf.reduce_sum(is_target)

    # Gumbel-softmax decoding, used in training.
    # NOTE(review): the `start_token` argument is ignored; start ids are
    # hard-coded to 1 and the end id to 2 -- confirm these match the vocab.
    start_tokens = np.ones(batch_size, int)
    end_token = 2
    gumbel_helper = GumbelSoftmaxEmbeddingHelper(tgt_embedding, start_tokens,
                                                 end_token, temperature)

    gumbel_outputs, _ = decoder(
        memory=encoder_output_,
        memory_sequence_length=text_keyword_length,
        helper=gumbel_helper)

    gen_o = tf.reduce_sum(tf.reduce_max(gumbel_outputs.logits, axis=2))

    return gumbel_outputs.logits, gumbel_outputs.sample_id, pretrain_loss, gen_o