Example #1
def predict(test_post_file,
            max_len,
            vocab,
            rev_vocab,
            word_embeddings,
            encoder,
            generator,
            output_file=None):
    # data generator
    test_data_generator = batcher(1, test_post_file, response_file=None)

    if output_file:
        fo = open(output_file, 'wb')

    while True:
        try:
            post_sentence = test_data_generator.next()
        except StopIteration:
            logger.info('---------------------finish-------------------------')
            break

        post_ids = [sentence2id(sent, vocab) for sent in post_sentence]
        posts_var, posts_length = padding_inputs(post_ids, None)
        if USE_CUDA:
            posts_var = posts_var.cuda()

        embedded_post = word_embeddings(posts_var)
        _, dec_init_state = encoder(embedded_post,
                                    input_lengths=posts_length.numpy())
        log_softmax_outputs = generator.inference(
            dec_init_state, word_embeddings)  # [B, T, vocab_size]
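        # note: the greedy inference scores above are unused in this snippet;
        # decoding is done with beam search (beam width 5, penalty 1.0, best hypothesis only)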

        hyps, _ = beam_search(dec_init_state,
                              max_len,
                              word_embeddings,
                              generator,
                              beam=5,
                              penalty=1.0,
                              nbest=1)
        results = []
        for h in hyps:
            results.append(id2sentence(h[0], rev_vocab))

        print('*******************************************************')
        print "post:" + ''.join(post_sentence[0])
        print "response:\n" + '\n'.join([''.join(r) for r in results])
        print
Example #2
def adversarial():
    # use the project logger
    logger = logging.getLogger("lan2720")
    
    argparser = argparse.ArgumentParser(add_help=False)
    argparser.add_argument('--load_path', '-p', type=str, required=True)
    # TODO: load best
    argparser.add_argument('--load_epoch', '-e', type=int, required=True)
    
    argparser.add_argument('--filter_num', type=int, required=True)
    argparser.add_argument('--filter_sizes', type=str, required=True)

    argparser.add_argument('--training_ratio', type=int, default=2)
    argparser.add_argument('--g_learning_rate', '-glr', type=float, default=0.001)
    argparser.add_argument('--d_learning_rate', '-dlr', type=float, default=0.001)
    argparser.add_argument('--batch_size', '-b', type=int, default=168)
    
    # new arguments used in adversarial training
    new_args = argparser.parse_args()
    
    # load default arguments
    default_arg_file = os.path.join(new_args.load_path, 'args.pkl')
    if not os.path.exists(default_arg_file):
        raise RuntimeError('No default argument file in %s' % new_args.load_path)
    else:
        with open(default_arg_file, 'rb') as f:
            args = pickle.load(f)
    
    args.mode = 'adversarial'
    #args.d_learning_rate  = 0.0001
    args.print_every = 1
    args.g_learning_rate = new_args.g_learning_rate
    args.d_learning_rate = new_args.d_learning_rate
    args.batch_size = new_args.batch_size

    # add new arguments
    args.load_path = new_args.load_path
    args.load_epoch = new_args.load_epoch
    args.filter_num = new_args.filter_num
    args.filter_sizes = new_args.filter_sizes
    args.training_ratio = new_args.training_ratio
    


    # set up the output directory
    exp_dirname = os.path.join(args.exp_dir, args.mode, time.strftime("%Y-%m-%d-%H-%M-%S"))
    os.makedirs(exp_dirname)

    # set up the logger
    tqdm_logging.config(logger, os.path.join(exp_dirname, 'adversarial.log'), 
                        mode='w', silent=False, debug=True)

    # load vocabulary
    vocab, rev_vocab = load_vocab(args.vocab_file, max_vocab=args.max_vocab_size)

    vocab_size = len(vocab)

    word_embeddings = nn.Embedding(vocab_size, args.emb_dim, padding_idx=SYM_PAD)
    E = EncoderRNN(vocab_size, args.emb_dim, args.hidden_dim, args.n_layers, args.dropout_rate, bidirectional=True, variable_lengths=True)
    G = Generator(vocab_size, args.response_max_len, args.emb_dim, 2*args.hidden_dim, args.n_layers, dropout_p=args.dropout_rate)
    D = Discriminator(args.emb_dim, args.filter_num, eval(args.filter_sizes))
    
    if args.use_cuda:
        word_embeddings.cuda()
        E.cuda()
        G.cuda()
        D.cuda()

    # define optimizer
    opt_G = torch.optim.Adam(G.rnn.parameters(), lr=args.g_learning_rate)
    opt_D = torch.optim.Adam(D.parameters(), lr=args.d_learning_rate)
    
    logger.info('----------------------------------')
    logger.info('Adversarially train a neural conversation model')
    logger.info('----------------------------------')

    logger.info('Args:')
    logger.info(str(args))
    
    logger.info('Vocabulary from ' + args.vocab_file)
    logger.info('vocabulary size: %d' % vocab_size)
    logger.info('Loading text data from ' + args.train_query_file + ' and ' + args.train_response_file)
   
    
    reload_model(args.load_path, args.load_epoch, word_embeddings, E, G)
    #    start_epoch = args.resume_epoch + 1
    #else:
    #    start_epoch = 0

    # dump args
    with open(os.path.join(exp_dirname, 'args.pkl'), 'wb') as f:
        pickle.dump(args, f)


    # TODO: num_epoch is the old value loaded from the pre-training args
    for e in range(args.num_epoch):
        train_data_generator = batcher(args.batch_size, args.train_query_file, args.train_response_file)
        logger.info("Epoch: %d/%d" % (e, args.num_epoch))
        step = 0
        cur_time = time.time() 
        while True:
            try:
                post_sentences, response_sentences = train_data_generator.next()
            except StopIteration:
                # save model
                save_model(exp_dirname, e, word_embeddings, E, G, D) 
                ## evaluation
                #eval(args.valid_query_file, args.valid_response_file, args.batch_size, 
                #        word_embeddings, E, G, loss_func, args.use_cuda, vocab, args.response_max_len)
                break
            
            # prepare data
            post_ids = [sentence2id(sent, vocab) for sent in post_sentences]
            response_ids = [sentence2id(sent, vocab) for sent in response_sentences]
            posts_var, posts_length = padding_inputs(post_ids, None)
            responses_var, responses_length = padding_inputs(response_ids, args.response_max_len)
            # sort by post length
            posts_length, perms_idx = posts_length.sort(0, descending=True)
            posts_var = posts_var[perms_idx]
            responses_var = responses_var[perms_idx]
            responses_length = responses_length[perms_idx]

            if args.use_cuda:
                posts_var = posts_var.cuda()
                responses_var = responses_var.cuda()

            embedded_post = word_embeddings(posts_var)
            real_responses = word_embeddings(responses_var)

            # forward
            _, dec_init_state = E(embedded_post, input_lengths=posts_length.numpy())
            fake_responses = G(dec_init_state, word_embeddings) # [B, T, emb_size]

            prob_real = D(embedded_post, real_responses)
            prob_fake = D(embedded_post, fake_responses)
        
            # loss
            D_loss = - torch.mean(torch.log(prob_real) + torch.log(1. - prob_fake)) 
            G_loss = torch.mean(torch.log(1. - prob_fake))
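            # standard minimax GAN objective: D maximizes log D(real) + log(1 - D(fake)),
            # so D_loss is the negated mean; G minimizes log(1 - D(fake)) directly.
            # D is stepped only every `training_ratio` iterations, while G is stepped every iteration.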
            
            if step % args.training_ratio == 0:
                opt_D.zero_grad()
                D_loss.backward(retain_graph=True)
                opt_D.step()
            
            opt_G.zero_grad()
            G_loss.backward()
            opt_G.step()
            
            if step % args.print_every == 0:
                logger.info('Step %5d: D accuracy=%.2f (0.5 for D to converge) D score=%.2f (-1.38 for G to converge) (%.1f iters/sec)' % (
                    step, 
                    prob_real.cpu().data.numpy().mean(), 
                    -D_loss.cpu().data.numpy()[0], 
                    args.print_every/(time.time()-cur_time)))
                cur_time = time.time()
            step = step + 1
Example #3
    console.setFormatter(fmt)
    logger.addHandler(console)
    logfile = logging.FileHandler(args.log_file, 'a')

    logfile.setFormatter(fmt)
    logger.addHandler(logfile)

    if args.test:
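        # evaluation-only mode: load a saved model, score the test set, then exit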
        test_graph = load_data(args.test_file)
        model = torch.load(args.load_model)
        model.device = device

        dev_dataset = QBLINKDataset(test_graph, model, False)
        dev_loader = DataLoader(dataset=dev_dataset,
                                batch_size=1,
                                collate_fn=batcher(device),
                                shuffle=False,
                                num_workers=0)
        model.to(device)
        score, total_list = evaluate(dev_loader, model)
        exit()

    train_graph = load_data(args.train_file)
    dev_graph = load_data(args.dev_file)

    word_dict = load_words(args, train_graph)
    model = Model(args, word_dict)
    model.device = device
    model.load_embeddings(word_dict.tokens(), args.embedding_file)

    train_dataset = QBLINKDataset(train_graph, model, True)
Example #4
def eval(valid_query_file, valid_response_file, batch_size, word_embeddings, E,
         G, loss_func, use_cuda, vocab, response_max_len):
    logger.info('---------------------validating--------------------------')
    logger.info('Loading valid data from %s and %s' %
                (valid_query_file, valid_response_file))

    valid_data_generator = batcher(batch_size, valid_query_file,
                                   valid_response_file)

    sum_loss = 0.0
    valid_char_num = 0
    example_num = 0
    while True:
        try:
            post_sentences, response_sentences = valid_data_generator.next()
        except StopIteration:
            # one epoch finish
            break

        post_ids = [sentence2id(sent, vocab) for sent in post_sentences]
        response_ids = [
            sentence2id(sent, vocab) for sent in response_sentences
        ]
        posts_var, posts_length = padding_inputs(post_ids, None)
        responses_var, responses_length = padding_inputs(
            response_ids, response_max_len)

        # sort by post length
        posts_length, perms_idx = posts_length.sort(0, descending=True)
        posts_var = posts_var[perms_idx]
        responses_var = responses_var[perms_idx]
        responses_length = responses_length[perms_idx]

        # append EOS to each sentence: pad a zero column, then write SYM_EOS at each response length
        references_var = torch.cat([
            responses_var,
            Variable(torch.zeros(responses_var.size(0), 1).long(),
                     requires_grad=False)
        ],
                                   dim=1)
        for idx, length in enumerate(responses_length):
            references_var[idx, length] = SYM_EOS

        if use_cuda:
            posts_var = posts_var.cuda()
            responses_var = responses_var.cuda()
            references_var = references_var.cuda()

        embedded_post = word_embeddings(posts_var)
        embedded_response = word_embeddings(responses_var)
        _, dec_init_state = E(embedded_post,
                              input_lengths=posts_length.numpy())
        log_softmax_outputs = G.supervise(
            embedded_response, dec_init_state,
            word_embeddings)  # [B, T, vocab_size]

        outputs = log_softmax_outputs.view(-1, len(vocab))
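        # zero out log-probabilities at padded positions so they do not contribute to the loss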
        mask_pos = mask(references_var).view(-1).unsqueeze(-1)
        masked_output = outputs * (mask_pos.expand_as(outputs))
        loss = loss_func(masked_output, references_var.view(-1))

        sum_loss += loss.cpu().data.numpy()[0]
        example_num += posts_var.size(0)
        valid_char_num += torch.sum(mask_pos).cpu().data.numpy()[0]
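        # sum_loss accumulates the total NLL and valid_char_num the non-padding target
        # characters, so the perplexity reported below is exp(sum_loss / valid_char_num)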

    logger.info(
        'Valid Loss (per case): %.2f Valid Perplexity (per word): %.2f' %
        (sum_loss / example_num, math.exp(sum_loss / valid_char_num)))
    logger.info('---------------------finish-------------------------')
Example #5
def pretrain():
    # Parse command line arguments
    argparser = argparse.ArgumentParser()

    # train
    argparser.add_argument('--mode',
                           '-m',
                           choices=('pretrain', 'adversarial', 'inference'),
                           type=str,
                           required=True)
    argparser.add_argument('--batch_size', '-b', type=int, default=168)
    argparser.add_argument('--num_epoch', '-e', type=int, default=10)
    argparser.add_argument('--print_every', type=int, default=100)
    argparser.add_argument('--use_cuda', default=True)
    argparser.add_argument('--g_learning_rate',
                           '-glr',
                           type=float,
                           default=0.001)
    argparser.add_argument('--d_learning_rate',
                           '-dlr',
                           type=float,
                           default=0.001)

    # resume
    argparser.add_argument('--resume', action='store_true', dest='resume')
    argparser.add_argument('--resume_dir', type=str)
    argparser.add_argument('--resume_epoch', type=int)

    # save
    argparser.add_argument('--exp_dir', type=str, required=True)

    # model
    argparser.add_argument('--emb_dim', type=int, default=128)
    argparser.add_argument('--hidden_dim', type=int, default=256)
    argparser.add_argument('--dropout_rate', '-drop', type=float, default=0.5)
    argparser.add_argument('--n_layers', type=int, default=1)
    argparser.add_argument('--response_max_len', type=int, default=15)

    # data
    argparser.add_argument('--train_query_file',
                           '-tqf',
                           type=str,
                           required=True)
    argparser.add_argument('--train_response_file',
                           '-trf',
                           type=str,
                           required=True)
    argparser.add_argument('--valid_query_file',
                           '-vqf',
                           type=str,
                           required=True)
    argparser.add_argument('--valid_response_file',
                           '-vrf',
                           type=str,
                           required=True)
    argparser.add_argument('--vocab_file', '-vf', type=str, default='')
    argparser.add_argument('--max_vocab_size', '-mv', type=int, default=100000)

    args = argparser.parse_args()

    # set up the output directory
    exp_dirname = os.path.join(args.exp_dir, args.mode,
                               time.strftime("%Y-%m-%d-%H-%M-%S"))
    os.makedirs(exp_dirname)

    # set up the logger
    tqdm_logging.config(logger,
                        os.path.join(exp_dirname, 'train.log'),
                        mode='w',
                        silent=False,
                        debug=True)

    if not args.vocab_file:
        logger.info("no vocabulary file")
        build_vocab(args.train_query_file,
                    args.train_response_file,
                    seperated=True)
        sys.exit()
    else:
        vocab, rev_vocab = load_vocab(args.vocab_file,
                                      max_vocab=args.max_vocab_size)

    vocab_size = len(vocab)

    word_embeddings = nn.Embedding(vocab_size,
                                   args.emb_dim,
                                   padding_idx=SYM_PAD)
    E = EncoderRNN(vocab_size,
                   args.emb_dim,
                   args.hidden_dim,
                   args.n_layers,
                   args.dropout_rate,
                   bidirectional=True,
                   variable_lengths=True)
    G = Generator(vocab_size,
                  args.response_max_len,
                  args.emb_dim,
                  2 * args.hidden_dim,
                  args.n_layers,
                  dropout_p=args.dropout_rate)

    if args.use_cuda:
        word_embeddings.cuda()
        E.cuda()
        G.cuda()

    loss_func = nn.NLLLoss(size_average=False)
    params = list(word_embeddings.parameters()) + list(E.parameters()) + list(
        G.parameters())
    opt = torch.optim.Adam(params, lr=args.g_learning_rate)

    logger.info('----------------------------------')
    logger.info('Pre-train a neural conversation model')
    logger.info('----------------------------------')

    logger.info('Args:')
    logger.info(str(args))

    logger.info('Vocabulary from ' + args.vocab_file)
    logger.info('vocabulary size: %d' % vocab_size)
    logger.info('Loading text data from ' + args.train_query_file + ' and ' +
                args.train_response_file)

    # resume training from other experiment
    if args.resume:
        assert args.resume_epoch >= 0, 'If resume training, please assign resume_epoch'
        reload_model(args.resume_dir, args.resume_epoch, word_embeddings, E, G)
        start_epoch = args.resume_epoch + 1
    else:
        start_epoch = 0

    # dump args
    with open(os.path.join(exp_dirname, 'args.pkl'), 'wb') as f:
        pickle.dump(args, f)

    for e in range(start_epoch, args.num_epoch):
        logger.info('---------------------training--------------------------')
        train_data_generator = batcher(args.batch_size, args.train_query_file,
                                       args.train_response_file)
        logger.info("Epoch: %d/%d" % (e, args.num_epoch))
        step = 0
        total_loss = 0.0
        total_valid_char = []
        cur_time = time.time()
        while True:
            try:
                post_sentences, response_sentences = train_data_generator.next()
            except StopIteration:
                # save model
                save_model(exp_dirname, e, word_embeddings, E, G)
                # evaluation
                eval(args.valid_query_file, args.valid_response_file,
                     args.batch_size, word_embeddings, E, G, loss_func,
                     args.use_cuda, vocab, args.response_max_len)
                break

            post_ids = [sentence2id(sent, vocab) for sent in post_sentences]
            response_ids = [
                sentence2id(sent, vocab) for sent in response_sentences
            ]
            posts_var, posts_length = padding_inputs(post_ids, None)
            responses_var, responses_length = padding_inputs(
                response_ids, args.response_max_len)
            # sort by post length
            posts_length, perms_idx = posts_length.sort(0, descending=True)
            posts_var = posts_var[perms_idx]
            responses_var = responses_var[perms_idx]
            responses_length = responses_length[perms_idx]

            # append EOS to each sentence: pad a zero column, then write SYM_EOS at each response length
            references_var = torch.cat([
                responses_var,
                Variable(torch.zeros(responses_var.size(0), 1).long(),
                         requires_grad=False)
            ],
                                       dim=1)
            for idx, length in enumerate(responses_length):
                references_var[idx, length] = SYM_EOS

            # show case
            #for p, r, ref in zip(posts_var.data.numpy()[:10], responses_var.data.numpy()[:10], references_var.data.numpy()[:10]):
            #    print ''.join(id2sentence(p, rev_vocab))
            #    print ''.join(id2sentence(r, rev_vocab))
            #    print ''.join(id2sentence(ref, rev_vocab))
            #    print

            if args.use_cuda:
                posts_var = posts_var.cuda()
                responses_var = responses_var.cuda()
                references_var = references_var.cuda()

            embedded_post = word_embeddings(posts_var)
            embedded_response = word_embeddings(responses_var)

            _, dec_init_state = E(embedded_post,
                                  input_lengths=posts_length.numpy())
            log_softmax_outputs = G.supervise(
                embedded_response, dec_init_state,
                word_embeddings)  # [B, T, vocab_size]

            outputs = log_softmax_outputs.view(-1, vocab_size)
            mask_pos = mask(references_var).view(-1).unsqueeze(-1)
            masked_output = outputs * (mask_pos.expand_as(outputs))
            loss = loss_func(masked_output,
                             references_var.view(-1)) / (posts_var.size(0))
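            # masked NLL summed over all tokens, normalized by batch size (per-example loss)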

            opt.zero_grad()
            loss.backward()
            opt.step()

            total_loss += loss * (posts_var.size(0))
            total_valid_char.append(mask_pos)
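            # total_loss re-accumulates the summed NLL and total_valid_char the token masks,
            # so the perplexity printed below is exp(total NLL / number of non-padding tokens)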

            if step % args.print_every == 0:
                total_loss_val = total_loss.cpu().data.numpy()[0]
                total_valid_char_val = torch.sum(
                    torch.cat(total_valid_char, dim=1)).cpu().data.numpy()[0]
                logger.info(
                    'Step %5d: (per word) training perplexity %.2f (%.1f iters/sec)'
                    % (step, math.exp(total_loss_val / total_valid_char_val),
                       args.print_every / (time.time() - cur_time)))
                total_loss = 0.0
                total_valid_char = []
                total_case_num = 0
                cur_time = time.time()
            step = step + 1
Example #6
import tensorflow as tf

import tfutils.opt
import tfutils.modeling

import data
import model as mnist
import trainer

# Data
batcher = data.batcher()
x, y = batcher.placeholders()

# Model
logits = mnist.logits(x)  # unnormalized class scores

# Accuracy
y_hat = mnist.prediction(logits)  # predicted label
accuracy = tfutils.modeling.accuracy(y, y_hat)

# Optimization
step = tfutils.opt.global_step()
cost = mnist.cost(y, logits, regularize=True)
train_step = tf.train.MomentumOptimizer(0.0005, 0.9).minimize(
        cost, global_step=step)

model_vars = {
    'x': x,
    'y': y,
    'logits': logits,
    'y_hat': y_hat,
Example #7
flags.DEFINE_string('checkpoint_dir', 'logs/checkpoints/',
                    """directory containing model.pbtxt, saver.pbtxt, parameter
                    checkpoints""")
flags.DEFINE_boolean('use_validation_data', True,
                     """whether to use validation data or training data""")

model = predict.load(FLAGS.checkpoint_dir)
accuracy = modeling.accuracy(model.label_node, model.out_node)

num_examples = data.NUM_TEST if FLAGS.use_validation_data else data.NUM_TRAIN
steps = num_examples / data.BATCH_SIZE + 1

data_str = "test" if FLAGS.use_validation_data else "training"
print "Running " + data_str + " data"

with data.batcher() as batcher:
    with tf.Session() as sess:
        model.restore(sess)
        if FLAGS.use_validation_data:
            next_data = batcher.next_validation_batch
        else:
            next_data = batcher.next_training_batch

        accs = []
        for _ in xrange(steps):
            batch_x, batch_y = next_data()
            acc = sess.run(accuracy, feed_dict={model.in_node: batch_x,
                                                model.label_node: batch_y})
            accs.append(acc)
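        # average accuracy over all batches, reported as percent error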
        print "Error: %.2f" % (100 - (sum(accs) / len(accs)) * 100)