Example #1
    def run(self, sess, train_data_l, train_data_u, test_data, n_iter, keep_rate, save_dir, batch_size, alpha, FLAGS):
        self.init_global_step()
        with tf.name_scope('labeled'):
            with tf.variable_scope('classifier'):
                self.classifier_xa_l = self.classifier.create_placeholders('xa')
                self.classifier_y_l = self.classifier.create_placeholders('y')
                self.classifier_hyper_l = self.classifier.create_placeholders('hyper')
                logits_l = self.classifier.forward(self.classifier_xa_l, self.classifier_hyper_l)
                classifier_loss_l, classifier_acc_l, pri_loss_l = self.classifier.get_loss(logits_l, self.classifier_y_l, self.pri_prob_y)

            with tf.variable_scope('encoder'):
                self.encoder_xa_l = self.encoder.create_placeholders('xa')
                self.encoder_y_l = self.encoder.create_placeholders('y')
                self.encoder_hyper_l = self.encoder.create_placeholders('hyper')
                z_pst, z_pri, encoder_loss_l = self.encoder.forward(self.encoder_xa_l, self.encoder_y_l, self.encoder_hyper_l)

            with tf.variable_scope('decoder'):
                self.decoder_xa_l = self.decoder.create_placeholders('xa')  # x is included since x is generated sequentially
                self.decoder_y_l = self.decoder.create_placeholders('y')
                self.decoder_hyper_l = self.decoder.create_placeholders('hyper')
                decoder_loss_l, ppl_fw_l, ppl_bw_l, ppl_l = self.decoder.forward(self.decoder_xa_l, self.decoder_y_l, z_pst, self.decoder_hyper_l)
            elbo_l = encoder_loss_l * self.klw + decoder_loss_l - pri_loss_l

        self.loss_l = elbo_l
        self.loss_c = classifier_loss_l
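        # Labeled objective: KL term from the encoder scaled by the KL weight
        # self.klw, plus the decoder reconstruction loss, minus the label prior
        # term; the classifier loss is added separately (weighted by alpha) when
        # the total loss is assembled below.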
        
        with tf.name_scope('unlabeled'):
            with tf.variable_scope('classifier', reuse=True):
                self.classifier_xa_u = self.classifier.create_placeholders('xa')
                self.classifier_hyper_u = self.classifier.create_placeholders('hyper')
                logits_u = self.classifier.forward(self.classifier_xa_u, self.classifier_hyper_u)
                predict_u = tf.nn.softmax(logits_u)
                classifier_entropy_u = tf.losses.softmax_cross_entropy(predict_u, predict_u)
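                # The term above encourages confident (low-entropy) predictions on the
                # unlabeled batch; note that tf.losses.softmax_cross_entropy treats its
                # second argument as logits, so predict_u is softmaxed a second time here.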

            encoder_loss_u, decoder_loss_u = [], []
            elbo_u = []
            self.encoder_xa_u = self.encoder.create_placeholders('xa')
            self.encoder_hyper_u = self.encoder.create_placeholders('hyper')
            self.decoder_xa_u = self.decoder.create_placeholders('xa')
            self.decoder_hyper_u = self.decoder.create_placeholders('hyper')
            batch_size = tf.shape(list(self.encoder_xa_u.values())[0])[0]
            for idx in range(self.n_class):
                with tf.variable_scope('encoder', reuse=True):
                    _label = tf.gather(tf.eye(self.n_class), idx)
                    _label = tf.tile(_label[None, :], [batch_size, 1])
                    _z_pst, _, _encoder_loss = self.encoder.forward(self.encoder_xa_u, {'y':_label}, self.encoder_hyper_u)
                    encoder_loss_u.append(_encoder_loss * self.klw)
                    _pri_loss_u = tf.log(tf.gather(self.pri_prob_y, idx))
            
                with tf.variable_scope('decoder', reuse=True):
                    _decoder_loss, _, _, _ = self.decoder.forward(self.decoder_xa_u, {'y':_label}, _z_pst, self.decoder_hyper_u)
                    decoder_loss_u.append(_decoder_loss)

                _elbo_u = _encoder_loss * self.klw + _decoder_loss  # - _pri_loss_u
                elbo_u.append(_elbo_u)

        self.loss_u = tf.add_n([elbo_u[idx] * predict_u[:, idx] for idx in range(self.n_class)]) + classifier_entropy_u
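        # Unlabeled objective: the label is marginalized out by weighting each
        # class-conditional ELBO with the classifier's predicted probability for
        # that class, plus the confidence term on the prediction.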
        self.loss = tf.reduce_mean(self.loss_l + classifier_loss_l * alpha + self.loss_u)
        #self.loss = tf.reduce_mean(classifier_loss_l) 
        decoder_loss_l = tf.reduce_mean(decoder_loss_l)

        with tf.name_scope('train'):
            optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate).minimize(self.loss, global_step=self.global_step) 
            #optimizer = self.training_op(self.loss, tf.trainable_variables(), self.grad_clip, 20, self.learning_rate)

        
        summary_kl = tf.summary.scalar('kl', tf.reduce_mean(encoder_loss_l))
        summary_loss = tf.summary.scalar('loss', self.loss)
        summary_loss_l = tf.summary.scalar('loss_l', tf.reduce_mean(self.loss_l))
        summary_loss_u = tf.summary.scalar('loss_u', tf.reduce_mean(self.loss_u))
        summary_acc = tf.summary.scalar('acc', classifier_acc_l)
        summary_ppl_fw = tf.summary.scalar('ppl_fw', ppl_fw_l)
        summary_ppl_bw = tf.summary.scalar('ppl_bw', ppl_bw_l)
        summary_ppl = tf.summary.scalar('ppl', ppl_l)
        train_summary_op = tf.summary.merge_all()

        test_acc = tf.placeholder(tf.float32, [])
        test_ppl = tf.placeholder(tf.float32, [])
        summary_acc_test = tf.summary.scalar('test_acc', test_acc)
        summary_ppl_test = tf.summary.scalar('test_ppl', test_ppl)
        test_summary_op = tf.summary.merge([summary_acc_test, summary_ppl_test])

        logger = ExpLogger('semi_tabsa', save_dir)
        logger.write_args(FLAGS)
        logger.write_variables(tf.trainable_variables())
        logger.file_copy(['*.py', 'encoder/*.py', 'decoder/*.py', 'classifier/*.py'])

        train_summary_writer = tf.summary.FileWriter(save_dir + '/train', sess.graph)
        test_summary_writer = tf.summary.FileWriter(save_dir + '/test', sess.graph)
        validate_summary_writer = tf.summary.FileWriter(save_dir + '/validate', sess.graph)

        sess.run(tf.global_variables_initializer())
    
        def get_batch(dataset):
            """ to get batch from an iterator, whenever the ending is reached. """
            while True:
                try:
                    batch = dataset.next()
                    break
                except:
                    pass
            return batch
        
        def get_feed_dict_help(plhs, data_dict, keep_rate, is_training):
            plh_dict = {}
            for plh in plhs: plh_dict.update(plh)
            data_dict.update({'keep_rate': keep_rate})
            data_dict.update({'is_training': is_training})
            feed_dict = self.get_feed_dict(plh_dict, data_dict)
            return feed_dict

        max_acc = 0.
        for i in range(n_iter):
            #for train, _ in self.get_batch_data(train_data, keep_rate):
            for samples, in train_data_l:
                feed_dict_clf_l = get_feed_dict_help(plhs=[self.classifier_xa_l, self.classifier_y_l, self.classifier_hyper_l],
                        data_dict=self.classifier.prepare_data(samples),
                        keep_rate=keep_rate,
                        is_training=True)
                
                feed_dict_enc_l = get_feed_dict_help(plhs=[self.encoder_xa_l, self.encoder_y_l, self.encoder_hyper_l],
                        data_dict=self.encoder.prepare_data(samples),
                        keep_rate=keep_rate,
                        is_training=True)
 
                feed_dict_dec_l = get_feed_dict_help(plhs=[self.decoder_xa_l, self.decoder_y_l, self.decoder_hyper_l],
                        data_dict=self.decoder.prepare_data(samples),
                        keep_rate=keep_rate,
                        is_training=True)
                
                samples, = get_batch(train_data_u)
                feed_dict_clf_u = get_feed_dict_help(plhs=[self.classifier_xa_u, self.classifier_hyper_u],
                        data_dict=self.classifier.prepare_data(samples),
                        keep_rate=keep_rate,
                        is_training=True)
                
                feed_dict_enc_u = get_feed_dict_help(plhs=[self.encoder_xa_u, self.encoder_hyper_u],
                        data_dict=self.encoder.prepare_data(samples),
                        keep_rate=keep_rate,
                        is_training=True)
 
                feed_dict_dec_u = get_feed_dict_help(plhs=[self.decoder_xa_u, self.decoder_hyper_u],
                        data_dict=self.decoder.prepare_data(samples),
                        keep_rate=keep_rate,
                        is_training=True)

                feed_dict = {}
                feed_dict.update(feed_dict_clf_l)
                feed_dict.update(feed_dict_enc_l)
                feed_dict.update(feed_dict_dec_l)
                feed_dict.update(feed_dict_clf_u)
                feed_dict.update(feed_dict_enc_u)
                feed_dict.update(feed_dict_dec_u)
                feed_dict.update({self.klw: 0.0001})
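                # self.klw is the KL weight: it is held at a small constant (1e-4)
                # during training and set to 0 in the evaluation feed_dict below.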

                _, _acc, _loss, _ppl, _step, summary = sess.run([optimizer, classifier_acc_l, decoder_loss_l, ppl_l, self.global_step, train_summary_op], feed_dict=feed_dict)
                #_, _acc, _step, summary = sess.run([optimizer, classifier_acc_l, self.global_step, train_summary_op], feed_dict=feed_dict)
                train_summary_writer.add_summary(summary, _step)
                #if np.random.rand() < 1/4:
                #    print(_acc, _loss, _ppl, _step)
                
                acc, ppl, loss, cnt = 0., 0., 0., 0
                for samples, in test_data:
                    feed_dict_clf_l = get_feed_dict_help(plhs=[self.classifier_xa_l, self.classifier_y_l, self.classifier_hyper_l],
                            data_dict=self.classifier.prepare_data(samples),
                            keep_rate=1.0,
                            is_training=False)
                    
                    feed_dict_enc_l = get_feed_dict_help(plhs=[self.encoder_xa_l, self.encoder_y_l, self.encoder_hyper_l],
                            data_dict=self.encoder.prepare_data(samples),
                            keep_rate=1.0,
                            is_training=False)
     
                    feed_dict_dec_l = get_feed_dict_help(plhs=[self.decoder_xa_l, self.decoder_y_l, self.decoder_hyper_l],
                            data_dict=self.decoder.prepare_data(samples),
                            keep_rate=1.0,
                            is_training=False)
    
                    feed_dict = {}
                    feed_dict.update(feed_dict_clf_l)
                    feed_dict.update(feed_dict_enc_l)
                    feed_dict.update(feed_dict_dec_l)
                    feed_dict.update({self.klw: 0})
    
                    _acc, _loss, _ppl, _step = sess.run([classifier_acc_l, decoder_loss_l, ppl_l, self.global_step], feed_dict=feed_dict)
                    acc += _acc * len(samples)
                    ppl += _ppl * len(samples)
                    loss += _loss * len(samples)
                    cnt += len(samples)
                #print(cnt)
                #print(acc)
                summary, _step = sess.run([test_summary_op, self.global_step], feed_dict={test_acc: acc/cnt, test_ppl: ppl/cnt})
                test_summary_writer.add_summary(summary, _step)
                logger.info('Iter {}: mini-batch loss={:.6f}, test acc={:.6f}'.format(_step, loss / cnt, acc / cnt))
                #print(save_dir)
                _dir="unlabel10k" 
                if acc / cnt > max_acc:
                    max_acc = acc / cnt
     
        logger.info('Optimization Finished! Max acc={}'.format(max_acc))

        logger.info('Learning_rate={}, iter_num={}, hidden_num={}, l2={}'.format(
            self.learning_rate,
            n_iter,
            self.n_hidden,
            self.l2_reg
        ))
def selftraining(sess, classifier, label_data, unlabel_data, test_data, FLAGS):
    xa_inputs = classifier.create_placeholders('xa')
    hyper_inputs = classifier.create_placeholders('hyper')
    y_inputs = classifier.create_placeholders('y')

    logits = classifier.forward(xa_inputs, hyper_inputs)
    loss, acc, _ = classifier.get_loss(logits, y_inputs,
                                       [0.0] * classifier.n_class)
    pred = tf.argmax(logits, axis=1)
    prob = tf.reduce_max(tf.nn.softmax(logits), axis=1)
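    # pred and prob give the predicted class and its softmax confidence; they are
    # used below to select the most confident unlabeled samples for pseudo-labeling.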

    import datetime
    timestamp = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")

    save_dir = FLAGS.save_dir + '/selftraining/' + str(
        timestamp) + '/' + __file__.split('.')[0]
    print(save_dir)
    logger = ExpLogger('semi_tabsa', save_dir)
    logger.write_args(vars(FLAGS)['__flags'])
    logger.write_variables(tf.trainable_variables())
    logger.file_copy(
        ['*.py', 'encoder/*.py', 'decoder/*.py', 'classifier/*.py'])

    def get_feed_dict_help(classifier, plhs, data_dict, keep_rate,
                           is_training):
        plh_dict = {}
        for plh in plhs:
            plh_dict.update(plh)
        data_dict.update({'keep_rate': keep_rate})
        data_dict.update({'is_training': is_training})
        feed_dict = classifier.get_feed_dict(plh_dict, data_dict)
        return feed_dict

    with tf.name_scope('train'):
        loss = tf.reduce_mean(loss)
        optimizer = classifier.training_op(loss,
                                           tf.trainable_variables(),
                                           FLAGS.grad_clip,
                                           20,
                                           FLAGS.learning_rate,
                                           grads=None,
                                           opt='Adam')

    NUM_SELECT = 1000
    NUM_ITER = 500
    best_acc_in_rounds, best_f1_in_rounds = [], []
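    # Self-training outer loop: each round (re)trains the classifier for up to
    # NUM_ITER steps, pseudo-labels the NUM_SELECT most confident unlabeled
    # samples with the predicted polarity, moves them into the labeled set, and
    # repeats until the unlabeled pool is exhausted.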
    while len(unlabel_data):
        tf.global_variables_initializer().run()
        test_it = BatchIterator(len(test_data),
                                FLAGS.batch_size, [test_data],
                                testing=True)
        print(len(unlabel_data))

        selected = []
        new_unlabel = []
        it_cnt = 0
        best_acc, best_f1 = 0, 0
        while True:

            train_it = BatchIterator(len(label_data),
                                     FLAGS.batch_size, [label_data],
                                     testing=False)
            unlabel_it = BatchIterator(len(unlabel_data),
                                       FLAGS.batch_size, [unlabel_data],
                                       testing=True)

            for samples, in train_it:
                it_cnt += 1
                if it_cnt > NUM_ITER:
                    break
                feed_dict = get_feed_dict_help(
                    classifier,
                    plhs=[xa_inputs, y_inputs, hyper_inputs],
                    data_dict=classifier.prepare_data(samples),
                    keep_rate=FLAGS.keep_rate,
                    is_training=True)

                _, _loss, _acc, _step = sess.run(
                    [optimizer, loss, acc, classifier.global_step],
                    feed_dict=feed_dict)
                #print('Train: step {}, acc {}, loss {}'.format(it_cnt, _acc, _loss))

                ### proc test
                test_acc, cnt = 0, 0
                y_true = []
                y_pred = []
                for samples, in test_it:
                    data_dict = classifier.prepare_data(samples)
                    feed_dict = get_feed_dict_help(
                        classifier,
                        plhs=[xa_inputs, y_inputs, hyper_inputs],
                        data_dict=data_dict,
                        keep_rate=1.0,
                        is_training=False)

                    num = len(samples)
                    _acc, _loss, _pred, _step = sess.run(
                        [acc, loss, pred, classifier.global_step],
                        feed_dict=feed_dict)
                    y_pred.extend(list(_pred))
                    y_true.extend(list(np.argmax(data_dict['y'], 1)))
                    test_acc += _acc * num
                    cnt += num
                test_acc = test_acc / cnt
                test_f1 = f1_score(y_true, y_pred, average='macro')
                logger.info(
                    'Test: step {}, test acc={:.6f}, test f1={:.6f}'.format(
                        it_cnt, test_acc, test_f1))
                best_f1 = max(best_f1, test_f1)

                ### proc unlabel
                if best_acc < test_acc:
                    best_acc = test_acc
                    _unlabel = []
                    _preds = []
                    _probs = []
                    y_dict = {0: 'positive', 1: 'negative', 2: 'neutral'}

                    for samples, in unlabel_it:
                        feed_dict = get_feed_dict_help(
                            classifier,
                            plhs=[xa_inputs, hyper_inputs],
                            data_dict=classifier.prepare_data(samples),
                            keep_rate=1.0,
                            is_training=False)

                        _pred, _prob = sess.run([pred, prob],
                                                feed_dict=feed_dict)
                        _unlabel.extend(samples)
                        _preds.extend(list(_pred))
                        _probs.extend(list(_prob))

                    top_k_id = np.argsort(_probs)[::-1][:NUM_SELECT]
                    remain_id = np.argsort(_probs)[::-1][NUM_SELECT:]
                    selected = [_unlabel[idx] for idx in top_k_id]
                    preds = [_preds[idx] for idx in top_k_id]
                    for idx, sample in enumerate(selected):
                        sample['polarity'] = y_dict[preds[idx]]
                    new_unlabel = [_unlabel[idx] for idx in remain_id]

            if it_cnt > NUM_ITER:
                best_acc_in_rounds.append(best_acc)
                best_f1_in_rounds.append(best_f1)
                logger.info(str(best_acc_in_rounds) + str(best_f1_in_rounds))
                break

        label_data.extend(selected)
        unlabel_data = new_unlabel

    #print(max(best_acc_in_rounds), max(best_f1_in_rounds))
    logger.info(str(best_acc_in_rounds) + str(best_f1_in_rounds))
    def run(self, sess, train_data_l, train_data_u, test_data, n_iter, keep_rate, save_dir, batch_size, alpha, FLAGS):
        self.init_global_step()
        with tf.name_scope('labeled'):
            with tf.variable_scope('classifier'):
                self.classifier_xa_l = self.classifier.create_placeholders('xa')
                self.classifier_y_l = self.classifier.create_placeholders('y')
                self.classifier_hyper_l = self.classifier.create_placeholders('hyper')
                logits_l = self.classifier.forward(self.classifier_xa_l, self.classifier_hyper_l)
                classifier_loss_l, classifier_acc_l, pri_loss_l = self.classifier.get_loss(logits_l, self.classifier_y_l, self.pri_prob_y)
                pred_l = tf.argmax(logits_l, axis=1)

            with tf.variable_scope('encoder'):
                self.encoder_xa_l = self.encoder.create_placeholders('xa')
                self.encoder_y_l = self.encoder.create_placeholders('y')
                self.encoder_hyper_l = self.encoder.create_placeholders('hyper')
                z_pst, z_pri, encoder_loss_l = self.encoder.forward(self.encoder_xa_l, self.encoder_y_l, self.encoder_hyper_l)

            with tf.variable_scope('decoder'):
                self.decoder_xa_l = self.decoder.create_placeholders('xa')  # x is included since x is generated sequentially
                self.decoder_y_l = self.decoder.create_placeholders('y')
                self.decoder_hyper_l = self.decoder.create_placeholders('hyper')
                decoder_loss_l, ppl_fw_l, ppl_bw_l, ppl_l = self.decoder.forward(self.decoder_xa_l, self.decoder_y_l, z_pst, self.decoder_hyper_l)
            elbo_l = encoder_loss_l * self.klw + decoder_loss_l - pri_loss_l

        self.loss_l = elbo_l
        self.loss_c = classifier_loss_l
        
        with tf.name_scope('unlabeled'):
            with tf.variable_scope('classifier', reuse=True):
                self.classifier_xa_u = self.classifier.create_placeholders('xa')
                self.classifier_hyper_u = self.classifier.create_placeholders('hyper')
                logits_u = self.classifier.forward(self.classifier_xa_u, self.classifier_hyper_u)
                predict_u = tf.nn.softmax(logits_u)
                classifier_entropy_u = tf.losses.softmax_cross_entropy(predict_u, predict_u)

            encoder_loss_u, decoder_loss_u = [], []
            elbo_u = []
            self.encoder_xa_u = self.encoder.create_placeholders('xa')
            self.encoder_hyper_u = self.encoder.create_placeholders('hyper')
            self.decoder_xa_u = self.decoder.create_placeholders('xa')
            self.decoder_hyper_u = self.decoder.create_placeholders('hyper')
            batch_size = tf.shape(list(self.encoder_xa_u.values())[0])[0]
            for idx in range(self.n_class):
                with tf.variable_scope('encoder', reuse=True):
                    _label = tf.gather(tf.eye(self.n_class), idx)
                    _label = tf.tile(_label[None, :], [batch_size, 1])
                    _z_pst, _, _encoder_loss = self.encoder.forward(self.encoder_xa_u, {'y':_label}, self.encoder_hyper_u)
                    encoder_loss_u.append(_encoder_loss * self.klw)
                    _pri_loss_u = tf.log(tf.gather(self.pri_prob_y, idx))
            
                with tf.variable_scope('decoder', reuse=True):
                    _decoder_loss, _, _, _ = self.decoder.forward(self.decoder_xa_u, {'y':_label}, _z_pst, self.decoder_hyper_u)
                    decoder_loss_u.append(_decoder_loss)

                _elbo_u = _encoder_loss * self.klw + _decoder_loss  # - _pri_loss_u
                elbo_u.append(_elbo_u)

        self.loss_u = tf.add_n([elbo_u[idx] * predict_u[:, idx] for idx in range(self.n_class)]) + classifier_entropy_u
        self.loss = tf.reduce_mean(self.loss_l + classifier_loss_l * alpha + self.loss_u)
        self.loss += sum(tf.losses.get_regularization_losses())
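        # Regularization losses registered in the graph (e.g. L2 weight decay from
        # the sub-networks) are added to the training objective here as well.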

        batch_size_l = tf.shape(decoder_loss_l)[0]
        decoder_loss_l = tf.reduce_mean(decoder_loss_l)

        with tf.name_scope('train'):
            #optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate).minimize(cost, global_step=self.global_step)
            optimizer = self.training_op(self.loss, tf.trainable_variables(), self.grad_clip, 20, self.learning_rate, opt='Adam')
        
        summary_kl = tf.summary.scalar('kl', tf.reduce_mean(encoder_loss_l))
        summary_loss = tf.summary.scalar('loss', self.loss)
        summary_loss_l = tf.summary.scalar('loss_l', tf.reduce_mean(self.loss_l))
        summary_loss_u = tf.summary.scalar('loss_u', tf.reduce_mean(self.loss_u))
        summary_acc = tf.summary.scalar('acc', classifier_acc_l)
        summary_ppl_fw = tf.summary.scalar('ppl_fw', ppl_fw_l)
        summary_ppl_bw = tf.summary.scalar('ppl_bw', ppl_bw_l)
        summary_ppl = tf.summary.scalar('ppl', ppl_l)
        train_summary_op = tf.summary.merge_all()

        test_acc = tf.placeholder(tf.float32, [])
        test_ppl = tf.placeholder(tf.float32, [])
        summary_acc_test = tf.summary.scalar('test_acc', test_acc)
        summary_ppl_test = tf.summary.scalar('test_ppl', test_ppl)
        test_summary_op = tf.summary.merge([summary_acc_test, summary_ppl_test])

        logger = ExpLogger('semi_tabsa', save_dir)
        logger.write_args(FLAGS)
        logger.write_variables(tf.trainable_variables())
        logger.file_copy(['semi_tabsa.py', 'encoder/*.py', 'decoder/*.py', 'classifier/*.py'])

        summary_writer = tf.summary.FileWriter(save_dir + '/', sess.graph)
        #test_summary_writer = tf.summary.FileWriter(save_dir + '/', sess.graph)
        #validate_summary_writer = tf.summary.FileWriter(save_dir + '/validate', sess.graph)

        sess.run(tf.global_variables_initializer())

        def get_batch(dataset):
            """ to get batch from an iterator, whenever the ending is reached. """
            while True:
                try:
                    batch = dataset.next()
                    break
                except:
                    pass
            return batch
        
        def get_feed_dict_help(plhs, data_dict, keep_rate, is_training):
            plh_dict = {}
            for plh in plhs: plh_dict.update(plh)
            data_dict.update({'keep_rate': keep_rate})
            data_dict.update({'is_training': is_training})
            feed_dict = self.get_feed_dict(plh_dict, data_dict)
            return feed_dict

        max_acc, max_f1 = 0., 0.
        for i in range(n_iter):
            #for train, _ in self.get_batch_data(train_data, keep_rate):
            for samples, in train_data_l:
                feed_dict_clf_l = get_feed_dict_help(plhs=[self.classifier_xa_l, self.classifier_y_l, self.classifier_hyper_l],
                        data_dict=self.classifier.prepare_data(samples),
                        keep_rate=keep_rate,
                        is_training=True)
                
                feed_dict_enc_l = get_feed_dict_help(plhs=[self.encoder_xa_l, self.encoder_y_l, self.encoder_hyper_l],
                        data_dict=self.encoder.prepare_data(samples),
                        keep_rate=keep_rate,
                        is_training=True)
 
                feed_dict_dec_l = get_feed_dict_help(plhs=[self.decoder_xa_l, self.decoder_y_l, self.decoder_hyper_l],
                        data_dict=self.decoder.prepare_data(samples),
                        keep_rate=keep_rate,
                        is_training=True)
                
                samples, = get_batch(train_data_u)
                feed_dict_clf_u = get_feed_dict_help(plhs=[self.classifier_xa_u, self.classifier_hyper_u],
                        data_dict=self.classifier.prepare_data(samples),
                        keep_rate=keep_rate,
                        is_training=True)
                
                feed_dict_enc_u = get_feed_dict_help(plhs=[self.encoder_xa_u, self.encoder_hyper_u],
                        data_dict=self.encoder.prepare_data(samples),
                        keep_rate=keep_rate,
                        is_training=True)
 
                feed_dict_dec_u = get_feed_dict_help(plhs=[self.decoder_xa_u, self.decoder_hyper_u],
                        data_dict=self.decoder.prepare_data(samples),
                        keep_rate=keep_rate,
                        is_training=True)

                feed_dict = {}
                feed_dict.update(feed_dict_clf_l)
                feed_dict.update(feed_dict_enc_l)
                feed_dict.update(feed_dict_dec_l)
                feed_dict.update(feed_dict_clf_u)
                feed_dict.update(feed_dict_enc_u)
                feed_dict.update(feed_dict_dec_u)
                feed_dict.update({self.klw: 0.0001})

                _, _acc, _loss, _ppl, _step, summary = sess.run([optimizer, classifier_acc_l, decoder_loss_l, ppl_l, self.global_step, train_summary_op], feed_dict=feed_dict)
                summary_writer.add_summary(summary, _step)
                if np.random.rand() < 1/4:
                    print(_acc, _loss, _ppl, _step)
            
            truth, pred, acc, ppl, loss, cnt = [], [], 0., 0., 0., 0
            idx2y = {0:'positive', 1:'negative', 2:'neutral'}
            for samples, in test_data:
                feed_dict_clf_l = get_feed_dict_help(plhs=[self.classifier_xa_l, self.classifier_y_l, self.classifier_hyper_l],
                        data_dict=self.classifier.prepare_data(samples),
                        keep_rate=1.0,
                        is_training=False)
                
                feed_dict_enc_l = get_feed_dict_help(plhs=[self.encoder_xa_l, self.encoder_y_l, self.encoder_hyper_l],
                        data_dict=self.encoder.prepare_data(samples),
                        keep_rate=1.0,
                        is_training=False)
 
                feed_dict_dec_l = get_feed_dict_help(plhs=[self.decoder_xa_l, self.decoder_y_l, self.decoder_hyper_l],
                        data_dict=self.decoder.prepare_data(samples),
                        keep_rate=1.0,
                        is_training=False)

                feed_dict = {}
                feed_dict.update(feed_dict_clf_l)
                feed_dict.update(feed_dict_enc_l)
                feed_dict.update(feed_dict_dec_l)
                feed_dict.update({self.klw: 0})

                num, _pred, _acc, _loss, _ppl, _step = sess.run([batch_size_l, pred_l, classifier_acc_l, decoder_loss_l, ppl_l, self.global_step], feed_dict=feed_dict)
                pred.extend([idx2y[int(_)] for _ in _pred])
                truth.extend([sample['polarity'] for sample in samples])
                acc += _acc * num
                ppl += _ppl * num
                loss += _loss * num
                cnt += num
            #print(cnt)
            #print(acc)
            f1 = f1_score(truth, pred, average='macro') 
            summary, _step = sess.run([test_summary_op, self.global_step], feed_dict={test_acc: acc/cnt, test_ppl: ppl/cnt})
            summary_writer.add_summary(summary, _step)
            logger.info('Iter {}: mini-batch loss={:.6f}, test acc={:.6f}, f1={:.6f}'.format(_step, loss / cnt, acc / cnt, f1))
            print(save_dir)
            if acc / cnt > max_acc:
                max_acc = acc / cnt
                max_f1 = f1
                with open(os.path.join(save_dir, 'pred'), 'w') as f:
                    idx2y = {0:'positive', 1:'negative', 2:'neutral'}
                    for samples, in test_data:
                        feed_dict = get_feed_dict_help(plhs=[self.classifier_xa_l, self.classifier_y_l, self.classifier_hyper_l],
                            data_dict=self.classifier.prepare_data(samples),
                            keep_rate=1.0,
                            is_training=False)

                        _pred, = sess.run([pred_l], feed_dict=feed_dict)
                        for idx, sample in enumerate(samples):
                            f.write(idx2y[_pred[idx]])
                            f.write('\t')
                            f.write(sample['polarity'])
                            f.write('\t')
                            f.write(' '.join([t + ' ' + str(b) for (t, b) in zip(sample['tokens'], sample['tags'])]))
                            f.write('\n')


        logger.info('Optimization Finished! Max acc={} f1={}'.format(max_acc, max_f1))

        logger.info('Learning_rate={}, iter_num={}, hidden_num={}, l2={}'.format(
            self.learning_rate,
            n_iter,
            self.n_hidden,
            self.l2_reg
        ))