Exemplo n.º 1
0
def chat(args):
    """Interactive console: read lines from stdin and feed each to the model.

    Builds the model with ``feed_previous=True`` and a batch size of 1
    (``args.batch_size`` is overwritten as a side effect). It then restores
    the latest checkpoints found in ``args.cyc_dir``, ``args.class_dir`` and
    ``args.model_dir``, in that order, each through its own saver. Missing
    checkpoints are skipped silently.

    Every line typed after the ``Input:`` prompt, including its trailing
    newline, is handed to ``get_sentence``. The session ends at end-of-file.
    """
    with tf.Session() as sess:
        _, _, vocab_path = data_utils.prepare_diologue(args.works_dir, args.vocab_size)
        print()
        print('-------loading model-------')
        args.batch_size = 1
        model = v_autoencoder(sess, args, feed_previous=True)
        sess.run(tf.global_variables_initializer())

        # (args attribute holding the checkpoint dir, model attribute holding
        # the saver). Order is significant: if savers share variables, the
        # last restore wins. Attributes are looked up lazily, only when needed.
        restore_plan = (
            ('cyc_dir', 'trans_saver'),
            ('class_dir', 'class_saver'),
            ('model_dir', 'saver'),
        )
        for dir_attr, saver_attr in restore_plan:
            state = tf.train.get_checkpoint_state(getattr(args, dir_attr))
            if not (state and state.model_checkpoint_path):
                continue
            print("Reading model parameters from %s @ %s" % (state.model_checkpoint_path, datetime.now()))
            getattr(model, saver_attr).restore(sess, state.model_checkpoint_path)
            print("Model reloaded @ %s" % (datetime.now()))

        vocab, rev_vocab = data_utils.initialize_vocab(vocab_path)

        def prompt():
            # Show the prompt immediately, then block for one line of input.
            sys.stdout.write("Input:     ")
            sys.stdout.flush()
            return sys.stdin.readline()

        # readline() returns '' only at EOF (a blank line is still '\n').
        for sentence in iter(prompt, ''):
            get_sentence(sess, sentence, vocab, rev_vocab, model)
Exemplo n.º 2
0
def step1(args): # training variational autoencoder
    """Train the variational autoencoder; runs until interrupted.

    Restores the latest checkpoint in ``args.model_dir`` when one exists;
    otherwise it starts from freshly initialised variables. It then loops
    forever over training batches. Every ``checkpoint_step`` steps it:

    * logs the averaged losses;
    * saves ``model.ckpt`` into ``args.model_dir``;
    * evaluates a single batch drawn from the dev set.

    Args:
        args: parsed options. This function reads ``works_dir``,
            ``vocab_size`` and ``model_dir``, and passes ``args`` on to
            ``v_autoencoder``.
    """
    def perplexity(loss):
        # math.exp raises OverflowError for arguments above ~709, which would
        # abort a long training run just to print a log line if the loss
        # diverges. Report inf instead, as the TF seq2seq tutorial does.
        loss = float(loss)
        return math.exp(loss) if loss < 300 else float('inf')

    train_id_path, dev_id_path, vocab_path = data_utils.prepare_diologue(args.works_dir, args.vocab_size)

    checkpoint_step = 200
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=1)
    with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
        print()
        print('-------loading model-------')
        model = v_autoencoder(sess, args, feed_previous=False)
        # Initialise every variable first (both paths need it); a successful
        # restore then overwrites the variables the saver holds.
        sess.run(tf.global_variables_initializer())
        ckpt = tf.train.get_checkpoint_state(args.model_dir)
        if ckpt and ckpt.model_checkpoint_path:
            print("Reading model parameters from %s @ %s" % (ckpt.model_checkpoint_path, datetime.now()))
            model.saver.restore(sess, ckpt.model_checkpoint_path)
            print("Model reloaded @ %s" % (datetime.now()))
        else:
            print("Created model with fresh parameters.")

        print('loading dictionary...')
        vocab, rev_vocab = data_utils.initialize_vocab(vocab_path)

        print('loading dev set from %s...' % dev_id_path)
        dev_set = data_utils.read_data(dev_id_path)
        print('loading train set from %s...' % train_id_path)
        train_set = data_utils.read_data(train_id_path)
        print('-------start training-------')
        # Running means over the current checkpoint window.
        seq_loss, kl_loss = 0.0, 0.0
        current_step = 0
        checkpoint_path = os.path.join(args.model_dir, "model.ckpt")

        while True:
            encoder_inputs, decoder_inputs, weights = model.train_get_batch(train_set)
            step_seq_loss, step_kl_loss = model.ae_step(sess, encoder_inputs, decoder_inputs, weights, False)
            # Advance the KL-weight and sampling-rate update ops once per step.
            sess.run(model.kl_weight_op)
            sess.run(model.sample_rate_op)
            seq_loss += step_seq_loss / checkpoint_step
            kl_loss += step_kl_loss / checkpoint_step
            current_step += 1
            if current_step % checkpoint_step == 0:
                # NOTE: the value logged as "seq_loss" is exp(mean loss),
                # i.e. a perplexity; the label is kept for log compatibility.
                print("global step %d seq_loss %.4f kl_loss %.4f @ %s" % (model.global_step.eval(), perplexity(seq_loss), kl_loss, datetime.now()))

                model.saver.save(sess, checkpoint_path)
                seq_loss, kl_loss = 0.0, 0.0
                encoder_inputs, decoder_inputs, weights = model.train_get_batch(dev_set)
                step_seq_loss, step_kl_loss = model.ae_step(sess, encoder_inputs, decoder_inputs, weights, True)
                print("  eval: seq_loss %.2f" % (perplexity(step_seq_loss)))
                sys.stdout.flush()
Exemplo n.º 3
0
def test(args):
    """Run each line of ``seq2seq.txt`` through the model into ``neg2pos.txt``.

    Both files live in ``args.works_dir``. The model is built with
    ``feed_previous=True`` and a batch size of 1 (``args.batch_size`` is
    overwritten as a side effect). The latest checkpoints found in
    ``args.cyc_dir``, ``args.class_dir`` and ``args.model_dir`` are restored
    in that order; missing checkpoints are skipped silently.

    Each input line is stripped and passed to ``get_sentence``. Its return
    value is written as one line of the output file, and a running line
    counter is shown on stdout.
    """
    # The first call's results are unused; the vocabulary path comes from
    # prepare_cyc_diologue. Both calls are kept as in the original.
    _, _, _ = data_utils.prepare_diologue(
        args.works_dir, args.vocab_size)
    _, _, vocab_path = data_utils.prepare_cyc_diologue(
        args.works_dir, args.vocab_size)

    args.batch_size = 1
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=1)
    with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
        print()
        print('-------loading model-------')
        model = v_autoencoder(sess, args, feed_previous=True)
        sess.run(tf.global_variables_initializer())
        # (args attribute with the checkpoint dir, model attribute with the
        # saver), restored in this order; attributes are looked up lazily.
        restore_plan = (
            ('cyc_dir', 'trans_saver'),
            ('class_dir', 'class_saver'),
            ('model_dir', 'saver'),
        )
        for dir_attr, saver_attr in restore_plan:
            ckpt = tf.train.get_checkpoint_state(getattr(args, dir_attr))
            if ckpt and ckpt.model_checkpoint_path:
                print("Reading model parameters from %s @ %s" %
                      (ckpt.model_checkpoint_path, datetime.now()))
                getattr(model, saver_attr).restore(sess, ckpt.model_checkpoint_path)
                print("Model reloaded @ %s" % (datetime.now()))

        vocab, rev_vocab = data_utils.initialize_vocab(vocab_path)

        src_path = os.path.join(args.works_dir, 'seq2seq.txt')
        dst_path = os.path.join(args.works_dir, 'neg2pos.txt')
        with open(src_path, 'r') as fin, open(dst_path, 'w') as fout:
            for line_no, seq in enumerate(fin):
                # '\r' redraws the counter in place. With no newline, a
                # line-buffered terminal would never display it, so flush.
                print('\r%d' % line_no, end='', flush=True)
                sentence = seq.strip()
                out = get_sentence(sess, sentence, vocab, rev_vocab, model)
                fout.write(out + '\n')
        # Terminate the progress line so later output starts on a new line.
        print()
Exemplo n.º 4
0
    def chat(self):
        """Interactive demo: transfer each typed sentence and print the result.

        Prepares the data under ``self.works_dir`` and restores the latest
        checkpoint found there. For every line entered at the ``Input: ``
        prompt, it calls ``self.test_sentence`` and prints the ``fake_pos``
        and ``true_neg`` sentences that call returns. An empty line or end of
        input ends the session.

        Raises:
            ValueError: if no checkpoint is found in ``self.works_dir``.
        """
        print('---------prepare data---------')
        _, _, vocab_path = data_utils.prepare_diologue(self.works_dir, self.data_path, self.vocab_size)
        print('loading dictionary...')
        vocab, rev_vocab = data_utils.initialize_vocab(vocab_path)
        print('---------building model---------')
        self.sess.run(tf.global_variables_initializer())
        ckpt = tf.train.get_checkpoint_state(self.works_dir)
        if not (ckpt and ckpt.model_checkpoint_path):
            raise ValueError("can't find the path" )
        print("Reading model parameters from %s @ %s" % (ckpt.model_checkpoint_path, datetime.now()))
        self.saver.restore(self.sess, ckpt.model_checkpoint_path)
        print("Model reloaded @ %s" % (datetime.now()))

        # Raw string: '\W' in a plain literal is an invalid escape sequence
        # (DeprecationWarning; SyntaxWarning since Python 3.12). Same pattern.
        pat = re.compile(r'(\W+)')

        def read_sentence():
            # input() raises EOFError at end of input (Ctrl-D, closed pipe).
            # Treat that like an empty line so the session ends cleanly.
            try:
                return input('Input: ')
            except EOFError:
                print()
                return ''

        sentence = read_sentence()
        while sentence:
            true_neg_id, fake_pos_id = self.test_sentence(self.sess, sentence, vocab, rev_vocab, pat)
            print('fake_pos:', self.print_sentence(fake_pos_id, rev_vocab, 0))
            print('true_neg:', self.print_sentence(true_neg_id, rev_vocab, 0))
            sentence = read_sentence()
Exemplo n.º 5
0
    def train(self):
        """Optionally pretrain, then adversarially train the model.

        Prepares the pos/neg datasets and restores the latest checkpoint from
        ``self.works_dir`` when one exists. When ``self.pretrain`` is truthy,
        it first runs ``pretrain_iters`` pretraining batches. It then runs
        ``train_iters`` generator + discriminator steps.

        Averaged losses are logged every ``self.checkpoint_step`` steps.
        Every ``5 * self.checkpoint_step`` steps, the model is saved to
        ``works_dir/model.ckpt`` and sample sentences from the current batch
        are printed.
        """
        pretrain_iters = 2000
        train_iters = 200000

        print('---------prepare data---------')
        pos_id_path, neg_id_path, vocab_path = data_utils.prepare_diologue(self.works_dir, self.data_path, self.vocab_size)
        print('loading dictionary...')
        _, rev_vocab = data_utils.initialize_vocab(vocab_path)
        print('loading pos set from %s...' % pos_id_path)
        pos_set = data_utils.read_data(pos_id_path, 1.0)
        print('loading neg set from %s...' % neg_id_path)
        neg_set = data_utils.read_data(neg_id_path, 0.0)

        print('---------building model---------')
        self.sess.run(tf.global_variables_initializer())
        ckpt = tf.train.get_checkpoint_state(self.works_dir)
        if ckpt and ckpt.model_checkpoint_path:
            print("Reading model parameters from %s @ %s" % (ckpt.model_checkpoint_path, datetime.now()))
            self.saver.restore(self.sess, ckpt.model_checkpoint_path)
            print("Model reloaded @ %s" % (datetime.now()))
        else:
            print('Creating new parameters @ %s' % (datetime.now()))

        if self.pretrain:
            print('--------start pretraining-------')
            pos_seq_loss, neg_seq_loss = 0.0, 0.0
            for i in range(pretrain_iters):
                batch_encoder_inputs, batch_target_weights = self.get_batch(pos_set, neg_set)
                pos_sequence_loss, neg_sequence_loss = self.pretrain_step(self.sess, batch_encoder_inputs, batch_target_weights)

                pos_seq_loss += pos_sequence_loss / self.checkpoint_step
                neg_seq_loss += neg_sequence_loss / self.checkpoint_step

                if (i + 1) % self.checkpoint_step == 0:
                    # Both values are exp(mean sequence loss): perplexities.
                    print("iter: ", i, "pos_perplexity %.4f neg_perplexity %.4f @ %s" % (math.exp(pos_seq_loss), math.exp(neg_seq_loss), datetime.now()))
                    pos_seq_loss, neg_seq_loss = 0.0, 0.0

        # Running means over the current checkpoint window.
        g_loss = pos_d_loss = neg_d_loss = pos_seq_loss = neg_seq_loss = 0.0

        print('---------start training---------')
        for current_step in range(1, train_iters + 1):
            batch_encoder_inputs, batch_target_weights = self.get_batch(pos_set, neg_set)
            # The adversarial losses returned by g_step/d_step are not logged.
            gen_loss, _, _, pos_sequence_loss, neg_sequence_loss = self.g_step(self.sess, batch_encoder_inputs, batch_target_weights, train=True)
            pos_dis_loss, neg_dis_loss, _, _ = self.d_step(self.sess, batch_encoder_inputs, train=True)

            g_loss += gen_loss / self.checkpoint_step
            pos_d_loss += pos_dis_loss / self.checkpoint_step
            neg_d_loss += neg_dis_loss / self.checkpoint_step
            pos_seq_loss += pos_sequence_loss / self.checkpoint_step
            neg_seq_loss += neg_sequence_loss / self.checkpoint_step

            if current_step % self.checkpoint_step == 0:
                print('global step', self.sess.run(self.global_step), end='')
                print(" g_loss %.4f pos_d_loss %.4f neg_d_loss %.4f " % (g_loss, pos_d_loss, neg_d_loss), end='')
                print("pos_perplexity %.4f neg_perplexity %.4f @ %s" % (math.exp(pos_seq_loss), math.exp(neg_seq_loss), datetime.now()))
                g_loss = pos_d_loss = neg_d_loss = pos_seq_loss = neg_seq_loss = 0.0

            if current_step % (5 * self.checkpoint_step) == 0:
                checkpoint_path = os.path.join(self.works_dir, "model.ckpt")
                self.saver.save(self.sess, checkpoint_path)
                true_pos_id, true_neg_id, fake_pos_id, fake_neg_id = self.get_sentence(self.sess, batch_encoder_inputs)
                print('fake_pos:', self.print_sentence(fake_pos_id, rev_vocab, 0))
                print('fake_neg:', self.print_sentence(fake_neg_id, rev_vocab, 0))
                print('true_pos:', self.print_sentence(true_pos_id, rev_vocab, 0))
                print('true_neg:', self.print_sentence(true_neg_id, rev_vocab, 0))