Example #1
                model.claim: claim,
                model.warrant0_len: warrant0_len,
                model.warrant1_len: warrant1_len,
                model.reason_len: reason_len,
                model.claim_len: claim_len,
                model.input_keep_prob: 1,
                model.output_keep_prob: 1,
                model.keep_prob: 1
            }
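            # forward pass only: keep probabilities of 1 disable dropout at
            # prediction time; returns the predicted label probabilities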
            p_labels = sess.run(model.prob_labels, feed_dict)
            return p_labels

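        # target-task training loop over warrant / reason / claim batches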
        print 'start train..........'
        batches = data_helper.iter_batch(batch_size, num_epochs,
                                         train_warrant0, train_warrant1,
                                         train_reason, train_claim,
                                         train_labels, train_warrant0_len,
                                         train_warrant1_len, train_reason_len,
                                         train_claim_len)
        train_acc_list = []
        train_loss_list = []
        max_dev_acc = float('-inf')
        min_dev_loss = float('inf')
        num_undesc = 0
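        # early stopping: break once the dev metric has failed to improve
        # for more than max_num_undsc evaluations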
        for current_epoch, batch in batches:
            if num_undesc > max_num_undsc:
                break
            (warrant0, warrant1, reason, claim, labels, warrant0_len,
             warrant1_len, reason_len, claim_len) = batch
            step, train_batch_loss, train_batch_acc = train_step(
                warrant0, warrant1, reason, claim, labels, warrant0_len,
                warrant1_len, reason_len, claim_len)
            current_step = tf.train.global_step(sess, global_step)
        def s_dev_step(sen1_ids, sen2_ids, sen1_len, sen2_len, s_labels):
            # source-task evaluation step: dropout off, no gradient update;
            # returns the batch loss and accuracy
            feed_dict = {
                model.sen1_ids: sen1_ids,
                model.sen2_ids: sen2_ids,
                model.sen1_len: sen1_len,
                model.sen2_len: sen2_len,
                model.s_labels: s_labels,
                model.input_keep_prob: 1.0,
                model.output_keep_prob: 1.0,
            }
            step, loss, acc = sess.run(
                [global_step, model.s_loss, model.s_acc], feed_dict)
            return loss, acc

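        # source-task training loop over sentence-pair (sen1/sen2) batches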
        print 'Start Source training........'
        batches = data_helper.iter_batch(s_batch_size, s_num_epochs,
                                         sen1_train, sen2_train,
                                         sen1_train_len, sen2_train_len,
                                         s_train_labels)
        s_train_loss_list = []
        s_train_acc_list = []
        s_max_dev_acc = float('-inf')
        s_min_dev_loss = float('inf')
        s_num_undesc = 0
        for current_epoch, batch in batches:
            if s_num_undesc > s_max_num_undesc:
                break
            sen1_ids, sen2_ids, sen1_len, sen2_len, s_labels = batch
            step, s_train_batch_loss, s_train_batch_acc = s_train_step(
                sen1_ids, sen2_ids, sen1_len, sen2_len, s_labels)
            current_step = tf.train.global_step(sess, global_step)
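            # record per-batch loss and accuracy for the source task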
            s_train_loss_list.append(s_train_batch_loss)
            s_train_acc_list.append(s_train_batch_acc)