def train(self, sess, X_train, y_train, X_val=None, y_val=None):
        """Run the full training loop: shuffle every epoch, train per batch,
        and print precision/recall/f1 on the current batch every 100 iterations.

        Args:
            sess: open TensorFlow session with the model graph already built.
            X_train: 2-D array; column 0 is sliced out as entity input,
                columns 1..-2 as features — TODO confirm exact column layout.
            y_train: training targets, index-able by the same shuffle order.
            X_val, y_val: accepted but never used in this body — confirm
                whether validation is intentionally skipped.
        """
        # NOTE(review): saver is created but never used in this method.
        saver = tf.train.Saver()

        num_iterations = int(math.ceil(1.0 * len(X_train)/self.batch_size))
        for epoch in range(self.num_epochs):# shuffle train in each epoch
            shuffle_index = np.arange(len(X_train))
            np.random.shuffle(shuffle_index)
            X_train = X_train[shuffle_index]
            y_train = y_train[shuffle_index]

            for iteration in range(num_iterations):
                X_train_batch, y_train_batch = helper.nextBatch(X_train, y_train, start_index=iteration * self.batch_size, batch_size=self.batch_size)
                # NOTE(review): entity1 and entity2 are BOTH derived from column 0
                # (np.array([col0]).T vs X[:, 0:1] yield the same column as a
                # column vector) — entity2 looks like it should come from a
                # different column; verify against the data layout.
                X_train_entity1, X_train_entity2, X_train_features = np.transpose(np.array([X_train_batch[:,0]])), X_train_batch[:,0:1], X_train_batch[:,1:-1]
                #y_train_batch = np.transpose(np.array([y_train_batch]))  # (disabled) lists would need converting to a matrix for these ops
                _, logits , loss_train,train_summary = sess.run([self.optimizer, self.y_predict, self.cost,  self.train_summary], feed_dict = {self.input_entity1: X_train_entity1, self.input_entity2: X_train_entity2, self.features: X_train_features, self.targets: y_train_batch})
                predict_train = logits
                # Report train-batch metrics every 100 iterations (skip iteration 0).
                if iteration!= 0 and iteration % 100 == 0:
                    precision_train, recall_train, f1_train = self.evaluate(X_train_batch, y_train_batch, predict_train)
                    print('Iteration:%d\tprecision:%f\trecall:%f\tf1:%f\n'%(iteration,precision_train,recall_train,f1_train))
    def test(self, sess, X_test, y_test, output_path, id2label, doc_meta):
        """Predict a sense label for every test document and append one JSON
        object per document to ``output_path``.

        Args:
            sess: open TensorFlow session with the trained model graph.
            X_test: test inputs; helper.nextBatch splits each batch into the
                two model inputs (input1, input2).
            y_test: targets fed to the graph (required by the feed_dict).
            output_path: file appended with one JSON line per document
                (append mode — prior contents are preserved).
            id2label: maps predicted class ids to sense label strings.
            doc_meta: per-document metadata objects providing arg_token_list,
                connective, docid and type.
        """
        num_iterations = int(math.ceil(1.0 * len(X_test) / self.batch_size))

        print("number of iteration: " + str(num_iterations))
        y_hat = []

        # First pass: collect predictions for every batch.
        for iteration in range(num_iterations):
            X_test_batch1, X_test_batch2, y_test_batch = helper.nextBatch(
                X_test, y_test, iteration * self.batch_size, self.batch_size)

            pred = sess.run(
                [self.predictions],
                feed_dict={
                    self.input1: X_test_batch1,
                    self.input2: X_test_batch2,
                    self.targets: y_test_batch,
                })
            y_hat += list(pred[0])

        # Second pass: write results.  Open the output file ONCE instead of
        # re-opening it for every document (the original opened it inside the
        # loop); append mode is kept so behavior w.r.t. existing file contents
        # is unchanged.
        with open(output_path, 'a', encoding="utf8") as outfile:
            for i in range(len(doc_meta)):
                meta_data = doc_meta[i]

                # OrderedDict keeps the key order stable in the emitted JSON.
                dic = collections.OrderedDict()
                sense = id2label[y_hat[i]]

                dic["Arg1"] = {"TokenList": meta_data.arg_token_list[0]}
                dic["Arg2"] = {"TokenList": meta_data.arg_token_list[1]}
                dic["Connective"] = {
                    "TokenList": meta_data.connective["TokenList"]
                }
                dic["DocID"] = meta_data.docid

                dic["Sense"] = [sense]
                dic["Type"] = meta_data.type

                outfile.write(json.dumps(dic))
                outfile.write('\n')
# ---------- 示例#3 (Example 3) ----------
    def train_an_iteration(self, sess, save_file, X_train, y_train, X_val, y_val, char2id, id2char, label2id, id2label, summary_writer_train_source, summary_writer_train_target, summary_writer_val_source, summary_writer_val_target, flag="source", is_summary=False, is_validation=False, X_test_source=None, y_test_source=None):
        """Run a round of training iterations on the "source" or "target" data
        stream (selected by ``flag``), then optionally evaluate and checkpoint.

        Python 2 code (print statements).  ``flag`` selects which set of graph
        ops (optimizer/loss/scores/summary) is run.  When ``is_summary`` is
        true, evaluates on (X_test_source, y_test_source) afterwards and saves
        a checkpoint to ``save_file`` if the summed LOC+ORG+PER f1 improves on
        ``self.max_f1``.  The summary-writer args, X_val/y_val, char2id and
        is_validation are accepted but unused in this body.
        """

#        num_iterations = int(math.ceil(1.0 * len(X_train) / self.batch_size))
        # NOTE(review): iteration count is hard-coded to 128 (the computed
        # value above is commented out) — confirm this is intentional.
        num_iterations = 128
        cnt = 0  # NOTE(review): never used below
        for iteration in range(num_iterations):
            print "this iteration is %d, the flag is %s"%(iteration, flag)
            # train, the flag indicate the data source
            X_train_batch, y_train_batch = helper.nextBatch(X_train, y_train, start_index=iteration * self.batch_size, batch_size=self.batch_size)
#            y_train_weight_batch = 1 + np.array((y_train_batch == label2id['B']) | (y_train_batch == label2id['E']), float)
            transition_batch = helper.getTransition(y_train_batch)

            # Run the source- or target-specific ops depending on `flag`.
            _, loss_train, max_scores, max_scores_pre, length, train_summary =\
                sess.run([
                    self.optimizer_source if flag == "source" else self.optimizer_target,
                    self.loss_source if flag == "source" else self.loss_target,
                    self.max_scores_source if flag == "source" else self.max_scores_target,
                    self.max_scores_pre_source if flag == "source" else self.max_scores_pre_target,
                    self.length,
                    self.train_summary_source if flag == "source" else self.train_summary_target
                ], 
                feed_dict={
                    self.targets_transition:transition_batch, 
                    self.inputs:X_train_batch, 
                    self.targets:y_train_batch, 
 #                   self.targets_weight:y_train_weight_batch
                })
            print "the loss : %f"%loss_train
        # NOTE(review): `iteration` below leaks from the finished loop, so the
        # printed value is always num_iterations - 1.
        if is_summary:
              presicion_loc, recall_loc, f_loc, presicion_org, recall_org, f_org, presicion_per, recall_per, f_per = self.predictBatch(sess, X_test_source, y_test_source, id2label, id2char, label2id, self.batch_size, "target")
              print "iteration: %5d, %s valid , valid precision: LOC %.5f, ORG %.5f, PER %.5f, valid recall: LOC %.5f, ORG %.5f, PER %.5f, valid f1: LOC %.5f, ORG %.5f, PER %.5f" % (iteration, flag, presicion_loc, presicion_org, presicion_per, recall_loc, recall_org, recall_per, f_loc, f_org, f_per)

              # Checkpoint whenever the summed LOC/ORG/PER f1 improves.
              if f_loc + f_org + f_per >= self.max_f1:
                self.max_f1 = f_loc + f_org + f_per

                saver = tf.train.Saver()
                save_path = saver.save(sess, save_file)
                print "saved the best model with f1: %.5f" % (self.max_f1 / 3.0)

              self.last_f = f_loc + f_org + f_per
# ---------- 示例#4 (Example 4) ----------
    def train(self, sess, saver, save_file, X_train, y_train, X_valid, y_valid,
              X_train_tag, X_valid_tag, y_intent_train, y_intent_valid,
              model_dev, seq_len_train, seq_len_valid):
        """Jointly train the slot-filling + intent model, validating every 200
        iterations and saving the best checkpoint by summed validation f1.

        Two CRF variants are supported via ``self.crf_flag``:
          * 2 — custom CRF: decode with self.viterbi / model_dev.viterbi from
            the max-score matrices produced by the graph;
          * 3 — TF CRF layer: decode with viterbi_decode using the learned
            transition params.

        Args:
            sess: open TensorFlow session for the training graph.
            saver: tf.train.Saver used for checkpointing.
            save_file: accepted but unused here — the checkpoint path is
                hard-coded to "predict_output/model"; TODO confirm.
            X_train/y_train, X_valid/y_valid: slot inputs and tag targets.
            X_train_tag/X_valid_tag: auxiliary tag-feature inputs.
            y_intent_train/y_intent_valid: intent classification targets.
            model_dev: separate model instance whose ops are used for
                validation (presumably shares variables — verify).
            seq_len_train/seq_len_valid: per-example sequence lengths
                (shuffled but only fed via commented-out lines below).
        """
        char2id, id2char = helper.loadMap("meta_data/char2id")
        label2id, id2label = helper.loadMap("meta_data/label2id")

        num_iterations = int(math.ceil(1.0 * len(X_train) / self.batch_size))

        max_f1 = 0.0
        max_intent_acc = 0.0

        for epoch in range(self.num_epochs):
            # shuffle train in each epoch
            shuffle_index = np.arange(len(X_train))
            np.random.shuffle(shuffle_index)
            X_train = X_train[shuffle_index]
            seq_len_train = seq_len_train[shuffle_index]
            y_train = y_train[shuffle_index]
            X_train_tag = X_train_tag[shuffle_index]
            y_intent_train = y_intent_train[shuffle_index]
            print("current epoch: %d" % (epoch))
            for iteration in range(num_iterations):
                # train
                X_train_batch, y_train_batch, X_train_tag_batch, y_intent_train_batch, seq_len_batch_train = \
                    helper.nextBatch(X_train, y_train, X_train_tag, y_intent_train, seq_len_train, start_index=iteration * self.batch_size, batch_size=self.batch_size)

                # Up-weight boundary/special tags (B/E/X/Z/U/W) in the slot loss.
                y_train_weight_batch = 1 + np.array(
                    (y_train_batch == label2id['B']) |
                    (y_train_batch == label2id['E']) |
                    (y_train_batch == label2id['X']) |
                    (y_train_batch == label2id['Z']) |
                    (y_train_batch == label2id['U']) |
                    (y_train_batch == label2id['W']), float)
                transition_batch = helper.get_transition(y_train_batch)

                # crf_flag == 2: custom CRF path.
                if self.crf_flag == 2:
                    _, loss_train, max_scores, max_scores_pre, predicts_train_intent, len_train  = \
                    sess.run([
                        self.optimizer,
                        self.sum_loss,
                        self.max_scores,
                        self.max_scores_pre,
                        self.intent_prediction,
                        self.sequence_len,
                    ],
                        feed_dict={
                            self.targets_transition:transition_batch,
                            self.inputs:X_train_batch,
                            self.slot_targets:y_train_batch,
                            self.targets_weight:y_train_weight_batch,
                            self.input_tag:X_train_tag_batch,
                            self.intent_target:y_intent_train_batch,
                         #   self.sequence_len:seq_len_batch_train
                        })

                    # Log train-batch metrics every 100 iterations.
                    if iteration % 100 == 0:
                        predicts_train = self.viterbi(
                            max_scores,
                            max_scores_pre,
                            len_train,
                            predict_size=self.batch_size)
                        precision_train, recall_train, f1_train, acc_train = self.evaluate(
                            X_train_batch, y_train_batch, y_intent_train_batch,
                            predicts_train, predicts_train_intent, id2char,
                            id2label)
                        print(
                            "iteration, train loss, train precision, train recall, train f1, train acc",
                            iteration, loss_train, precision_train,
                            recall_train, f1_train, acc_train)
                # crf_flag == 3: TF CRF layer path.
                elif self.crf_flag == 3:
                    _, transition_params_train, slot_train_logits, loss_train, predicts_train_intent, train_seq_length = \
                        sess.run([
                            self.optimizer,
                            self.transition_params,
                            self.slot_logits,
                            self.sum_loss,
                            self.intent_prediction,
                            self.sequence_len,
                        ],
                            feed_dict={
                                # self.targets_transition: transition_batch,
                                self.inputs: X_train_batch,
                                self.slot_targets: y_train_batch,
                                # self.targets_weight: y_train_weight_batch,
                                self.input_tag: X_train_tag_batch,
                                self.intent_target: y_intent_train_batch,
                         #       self.sequence_len: seq_len_batch_train
                            })
                    if iteration % 100 == 0:
                        # Viterbi-decode each sequence up to its true length.
                        label_list = []
                        for logit, seq_len in zip(slot_train_logits,
                                                  train_seq_length):
                            if seq_len == 0:  # padding 0 at last of the data
                                break
                            viterbi_seq, _ = viterbi_decode(
                                logit[:seq_len], transition_params_train)
                            label_list.append(viterbi_seq)
                        predicts_train = label_list
                        precision_train, recall_train, f1_train, acc_train = self.evaluate(
                            X_train_batch, y_train_batch, y_intent_train_batch,
                            predicts_train, predicts_train_intent, id2char,
                            id2label)
                        print(
                            "iteration, train loss, train precision, train recall, train f1, train acc",
                            iteration, loss_train, precision_train,
                            recall_train, f1_train, acc_train)

                # validation
                # Runs at iteration 0 of every epoch too, so the sums below are
                # always defined before they are read.
                if iteration % 200 == 0:
                    f1_valid_sum = 0.0
                    acc_valid_sum = 0.0
                    loss_valid_sum = 0.0
                    precision_valid_sum = 0.0
                    recall_valid_sum = 0.0
                    num_iterations_valid = int(
                        math.ceil(1.0 * len(X_valid) / model_dev.batch_size))
                    for ttt in range(num_iterations_valid):
                        X_valid_batch, y_valid_batch, X_valid_input_tag_batch, y_intent_valid_batch, seq_len_valid_batch = \
                            helper.nextBatch(X_valid, y_valid, X_valid_tag, y_intent_valid, seq_len_valid, start_index=ttt * model_dev.batch_size, batch_size=model_dev.batch_size)

                        y_val_weight_batch = 1 + np.array(
                            (y_valid_batch == label2id['B']) |
                            (y_valid_batch == label2id['E']) |
                            (y_valid_batch == label2id['X']) |
                            (y_valid_batch == label2id['Z']) |
                            (y_valid_batch == label2id['U']) |
                            (y_valid_batch == label2id['W']), float)
                        transition_batch = helper.get_transition(y_valid_batch)
                        if self.crf_flag == 2:
                            loss_valid, max_scores, max_scores_pre, predicts_valid_intent, length_dev = \
                            sess.run([
                                model_dev.sum_loss,
                                model_dev.max_scores,
                                model_dev.max_scores_pre,
                                model_dev.intent_prediction,
                                model_dev.sequence_len,
                            ],
                                feed_dict={
                                    model_dev.targets_transition:transition_batch,
                                    model_dev.inputs:X_valid_batch,
                                    model_dev.slot_targets:y_valid_batch,
                                    model_dev.targets_weight:y_val_weight_batch,
                                    model_dev.input_tag:X_valid_input_tag_batch,
                                    model_dev.intent_target:y_intent_valid_batch,
                                    # model_dev.sequence_len:seq_len_valid_batch
                                })
                            predicts_valid = model_dev.viterbi(
                                max_scores,
                                max_scores_pre,
                                length_dev,
                                predict_size=model_dev.batch_size)
                        elif self.crf_flag == 3:
                            slot_train_logits, transition_params_train, length_dev, intent_prediction, loss_valid = \
                                sess.run([model_dev.slot_logits,
                                          model_dev.transition_params,
                                          model_dev.sequence_len,
                                          model_dev.intent_prediction,
                                          model_dev.sum_loss],
                                                                 feed_dict={
                                    # model_dev.targets_transition:transition_batch,
                                    model_dev.inputs:X_valid_batch,
                                    model_dev.slot_targets:y_valid_batch,
                                    # model_dev.targets_weight:y_val_weight_batch,
                                    model_dev.input_tag:X_valid_input_tag_batch,
                                    model_dev.intent_target:y_intent_valid_batch,
                                    # model_dev.sequence_len:seq_len_valid_batch
                                })
                            label_list = []
                            for logit, seq_len in zip(slot_train_logits,
                                                      length_dev):
                                if seq_len == 0:  # padding 0 at last of the data
                                    break
                                viterbi_seq, _ = viterbi_decode(
                                    logit[:seq_len], transition_params_train)
                                label_list.append(viterbi_seq)
                            predicts_valid = label_list
                            predicts_valid_intent = intent_prediction

                        precision_valid, recall_valid, f1_valid, acc_valid = \
                            model_dev.evaluate(X_valid_batch, y_valid_batch, y_intent_valid_batch, predicts_valid, predicts_valid_intent, id2char, id2label)

                        f1_valid_sum += f1_valid
                        acc_valid_sum += acc_valid
                        loss_valid_sum += loss_valid
                        precision_valid_sum += precision_valid
                        recall_valid_sum += recall_valid
                    # Checkpoint on best summed slot f1; intent acc tracked but
                    # not checkpointed (save intentionally commented out).
                    if f1_valid_sum > max_f1:
                        max_f1 = f1_valid_sum
                        saver.save(sess, "predict_output/model")
                    if acc_valid_sum > max_intent_acc:
                        max_intent_acc = acc_valid_sum
                        # saver.save(sess, "predict_output/model")
                    print(
                        "iteration, valid loss, valid precision, valid recall, valid f1, valid acc",
                        iteration, loss_valid_sum / num_iterations_valid,
                        precision_valid_sum / num_iterations_valid,
                        recall_valid_sum / num_iterations_valid,
                        f1_valid_sum / num_iterations_valid,
                        acc_valid_sum / num_iterations_valid)

        # NOTE(review): max_f1/max_intent_acc are sums over the validation
        # batches of whichever pass achieved them; dividing by the LAST pass's
        # num_iterations_valid gives per-batch averages — verify intent.
        print("max slot f1:", max_f1 / num_iterations_valid)
        print("max intent acc", max_intent_acc / num_iterations_valid)
# ---------- 示例#5 (Example 5) ----------
    def train(self, sess, save_file, X_train, y_train, X_val, y_val):
        """Train the sequence tagger, logging train metrics every 10 iterations,
        validating on a random batch every 100, and saving the checkpoint with
        the best validation f1 to ``save_file``.

        Args:
            sess: open TensorFlow session with the model graph already built.
            save_file: checkpoint path passed to tf.train.Saver.save.
            X_train, y_train: index-able training arrays; shuffled each epoch.
            X_val, y_val: validation arrays sampled via helper.nextRandomBatch.
        """
        saver = tf.train.Saver()

        char2id, id2char = helper.loadMap("char2id")
        label2id, id2label = helper.loadMap("label2id")

        # Use the standard TF1 summary API, consistent with the other training
        # loops in this file.  The previous tf.contrib.deprecated.merge_all_summaries /
        # tf.contrib.summary.SummaryWriter calls are deprecated, and the contrib
        # writer does not provide the add_summary() method used below.
        merged = tf.summary.merge_all()
        summary_writer_train = tf.summary.FileWriter('loss_log/train_loss', sess.graph)
        summary_writer_val = tf.summary.FileWriter('loss_log/val_loss', sess.graph)

        num_iterations = int(math.ceil(1.0 * len(X_train) / self.batch_size))

        cnt = 0  # counts logged train steps; shared x-axis for both writers
        for epoch in range(self.num_epochs):
            # shuffle train in each epoch
            sh_index = np.arange(len(X_train))
            np.random.shuffle(sh_index)
            X_train = X_train[sh_index]
            y_train = y_train[sh_index]
            print("current epoch: %d" % (epoch))
            for iteration in range(num_iterations):
                # train
                X_train_batch, y_train_batch = helper.nextBatch(
                    X_train,
                    y_train,
                    start_index=iteration * self.batch_size,
                    batch_size=self.batch_size)
                # Up-weight B/E boundary tags in the loss.
                y_train_weight_batch = 1 + np.array(
                    (y_train_batch == label2id['B']) |
                    (y_train_batch == label2id['E']), float)
                transition_batch = helper.getTransition(y_train_batch)

                _, loss_train, max_scores, max_scores_pre, length, train_summary = \
                    sess.run(
                        [
                            self.optimizer,
                            self.loss,
                            self.max_scores,
                            self.max_scores_pre,
                            self.length,
                            self.train_summary
                        ],
                        feed_dict={
                            self.targets_transition: transition_batch,
                            self.inputs: X_train_batch,
                            self.targets: y_train_batch,
                            self.targets_weight: y_train_weight_batch
                        })

                if iteration % 10 == 0:
                    cnt += 1
                    # Decode only when metrics are actually reported; the
                    # original ran this pure decoding step every iteration and
                    # discarded the result 9 times out of 10.
                    predicts_train = self.viterbi(max_scores,
                                                  max_scores_pre,
                                                  length,
                                                  predict_size=self.batch_size)
                    precision_train, recall_train, f1_train = self.evaluate(
                        X_train_batch, y_train_batch, predicts_train, id2char,
                        id2label)
                    summary_writer_train.add_summary(train_summary, cnt)
                    print(
                        "iteration: %5d, train loss: %5d, train precision: %.5f, train recall: %.5f, train f1: %.5f"
                        % (iteration, loss_train, precision_train,
                           recall_train, f1_train))

                # validation
                if iteration % 100 == 0:
                    X_val_batch, y_val_batch = helper.nextRandomBatch(
                        X_val, y_val, batch_size=self.batch_size)
                    y_val_weight_batch = 1 + np.array(
                        (y_val_batch == label2id['B']) |
                        (y_val_batch == label2id['E']), float)
                    transition_batch = helper.getTransition(y_val_batch)

                    loss_val, max_scores, max_scores_pre, length, val_summary = \
                        sess.run(
                            [
                                self.loss,
                                self.max_scores,
                                self.max_scores_pre,
                                self.length,
                                self.val_summary
                            ],
                            feed_dict={
                                self.targets_transition: transition_batch,
                                self.inputs: X_val_batch,
                                self.targets: y_val_batch,
                                self.targets_weight: y_val_weight_batch
                            })

                    predicts_val = self.viterbi(max_scores,
                                                max_scores_pre,
                                                length,
                                                predict_size=self.batch_size)
                    precision_val, recall_val, f1_val = self.evaluate(
                        X_val_batch, y_val_batch, predicts_val, id2char,
                        id2label)
                    summary_writer_val.add_summary(val_summary, cnt)
                    print(
                        "iteration: %5d, valid loss: %5d, valid precision: %.5f, valid recall: %.5f, valid f1: %.5f"
                        % (iteration, loss_val, precision_val, recall_val,
                           f1_val))

                    # Checkpoint whenever validation f1 improves.
                    if f1_val > self.max_f1:
                        self.max_f1 = f1_val
                        save_path = saver.save(sess, save_file)
                        print("saved the best model with f1: %.5f" %
                              (self.max_f1))
# ---------- 示例#6 (Example 6) ----------
    def train(self, sess, save_file, train_data, val_data):
        """Train the CRF tagger on dict-structured feature data.

        Python 2 code (print statements).  ``train_data``/``val_data`` are
        dicts with keys 'char', 'left', 'right', 'pos', 'lpos', 'rpos', 'rel',
        'dis' and 'label', each an index-able array of equal length.  Every
        epoch the training arrays are shuffled in lockstep; train metrics are
        printed every 10 iterations, one random validation batch is scored
        every 100, and at the end of each epoch (after the first) the whole
        dev set is decoded and scored with helper.calc_f1, checkpointing to
        ``save_file`` when f1 improves on ``self.max_f1``.
        """
        saver = tf.train.Saver(max_to_keep=3)

        #train data
        X_train = train_data['char']
        X_left_train = train_data['left']
        X_right_train = train_data['right']
        X_pos_train = train_data['pos']
        X_lpos_train = train_data['lpos']
        X_rpos_train = train_data['rpos']
        X_rel_train = train_data['rel']
        X_dis_train = train_data['dis']
        y_train = train_data['label']

        #dev data
        X_val = val_data['char']
        X_left_val = val_data['left']
        X_right_val = val_data['right']
        X_pos_val = val_data['pos']
        X_lpos_val = val_data['lpos']
        X_rpos_val = val_data['rpos']
        X_rel_val = val_data['rel']
        X_dis_val = val_data['dis']
        y_val = val_data['label']

        #dictionary
        char2id, id2char = helper.loadMap("char2id")
        pos2id, id2pos = helper.loadMap("pos2id")
        label2id, id2label = helper.loadMap("label2id")

        merged = tf.summary.merge_all()
        summary_writer_train = tf.summary.FileWriter('loss_log/train_loss', sess.graph)
        summary_writer_val = tf.summary.FileWriter('loss_log/val_loss', sess.graph)

        num_iterations = int(math.ceil(1.0 * len(X_train) / self.batch_size))

        cnt = 0  # counts logged train steps; x-axis for the summary writers
        for epoch in range(self.num_epochs):
            # shuffle train in each epoch (all feature arrays in lockstep)
            sh_index = np.arange(len(X_train))
            np.random.shuffle(sh_index)
            X_train = X_train[sh_index]
            X_left_train = X_left_train[sh_index]
            X_right_train = X_right_train[sh_index]
            X_pos_train = X_pos_train[sh_index]
            X_lpos_train = X_lpos_train[sh_index]
            X_rpos_train = X_rpos_train[sh_index]
            X_rel_train = X_rel_train[sh_index]
            X_dis_train = X_dis_train[sh_index]
            y_train = y_train[sh_index]

            # Write the shuffled arrays back so helper.nextBatch(train_data,...)
            # sees the new order (this mutates the caller's dict).
            train_data['char'] = X_train
            train_data['left'] = X_left_train
            train_data['right'] = X_right_train
            train_data['pos'] = X_pos_train
            train_data['lpos'] = X_lpos_train
            train_data['rpos'] = X_rpos_train
            train_data['rel'] = X_rel_train
            train_data['dis'] = X_dis_train
            train_data['label'] = y_train

            print "current epoch: %d" % (epoch)
            for iteration in range(num_iterations):
                # train 
                #get batch
                train_batches = helper.nextBatch(train_data, start_index=iteration * self.batch_size, batch_size=self.batch_size)
                X_train_batch = train_batches['char']
                X_left_train_batch = train_batches['left']
                X_right_train_batch = train_batches['right']
                X_pos_train_batch = train_batches['pos']
                X_lpos_train_batch = train_batches['lpos']
                X_rpos_train_batch = train_batches['rpos']
                X_rel_train_batch = train_batches['rel']
                X_dis_train_batch = train_batches['dis']
                y_train_batch = train_batches['label']

                # feed batch to model and run
                _, loss_train, length, train_summary, logits, trans_params =\
                    sess.run([
                        self.optimizer, 
                        self.loss, 
                        self.length,
                        self.train_summary,
                        self.logits,
                        self.trans_params,
                    ], 
                    feed_dict={
                        self.inputs:X_train_batch,
                        self.lefts:X_left_train_batch,
                        self.rights:X_right_train_batch,
                        self.poses:X_pos_train_batch,
                        self.lposes:X_lpos_train_batch,
                        self.rposes:X_rpos_train_batch,
                        self.rels:X_rel_train_batch,
                        self.dises:X_dis_train_batch,
                        self.targets:y_train_batch 
                        # self.targets_weight:y_train_weight_batch
                    })
                # print (len(length))

                #get predict f1
                predicts_train = self.viterbi(logits, trans_params, length, predict_size=self.batch_size)
                if iteration > 0 and iteration % 10 == 0:
                    cnt += 1
                    hit_num, pred_num, true_num = self.evaluate(y_train_batch, predicts_train, id2char, id2label)
                    precision_train, recall_train, f1_train = self.caculate(hit_num, pred_num, true_num)
                    summary_writer_train.add_summary(train_summary, cnt)
                    print "iteration: %5d/%5d, train loss: %5d, train precision: %.5f, train recall: %.5f, train f1: %.5f" % (iteration, num_iterations, loss_train, precision_train, recall_train, f1_train)  

                # a batch in validation
                if iteration > 0 and iteration % 100 == 0:
                    val_batches = helper.nextRandomBatch(val_data, batch_size=self.batch_size)

                    X_val_batch = val_batches['char']
                    X_left_val_batch = val_batches['left']
                    X_right_val_batch = val_batches['right']
                    X_pos_val_batch = val_batches['pos']
                    X_lpos_val_batch = val_batches['lpos']
                    X_rpos_val_batch = val_batches['rpos']
                    X_rel_val_batch = val_batches['rel']
                    X_dis_val_batch = val_batches['dis']
                    y_val_batch = val_batches['label']

                    loss_val, length, val_summary, logits, trans_params =\
                        sess.run([
                            self.loss, 
                            self.length,
                            self.val_summary,
                            self.logits,
                            self.trans_params,
                        ], 
                        feed_dict={
                            self.inputs:X_val_batch,
                            self.lefts:X_left_val_batch,
                            self.rights:X_right_val_batch,
                            self.poses:X_pos_val_batch,
                            self.lposes:X_lpos_val_batch,
                            self.rposes:X_rpos_val_batch,
                            self.rels:X_rel_val_batch,
                            self.dises:X_dis_val_batch,
                            self.targets:y_val_batch 
                            # self.targets_weight:y_val_weight_batch
                        })

                    predicts_val = self.viterbi(logits, trans_params, length, predict_size=self.batch_size)
                    hit_num, pred_num, true_num = self.evaluate(y_val_batch, predicts_val, id2char, id2label)
                    precision_val, recall_val, f1_val = self.caculate(hit_num, pred_num, true_num)
                    summary_writer_val.add_summary(val_summary, cnt)
                    print "iteration: %5d, valid loss: %5d, valid precision: %.5f, valid recall: %.5f, valid f1: %.5f" % (iteration, loss_val, precision_val, recall_val, f1_val)

                # calc f1 for the whole dev set
                # NOTE(review): skipped for epoch 0 (epoch > 0 guard) — confirm
                # the first epoch is intentionally excluded from full dev eval.
                if epoch > 0 and iteration == num_iterations -1:
                    num_val_iterations = int(math.ceil(1.0 * len(X_val) / self.batch_size))
                    preds_lines = []
                    for val_iteration in range(num_val_iterations):
                        val_batches = helper.nextBatch(val_data, start_index=val_iteration * self.batch_size, batch_size=self.batch_size)
                        X_val_batch = val_batches['char']
                        X_left_val_batch = val_batches['left']
                        X_right_val_batch = val_batches['right']
                        X_pos_val_batch = val_batches['pos']
                        X_lpos_val_batch = val_batches['lpos']
                        X_rpos_val_batch = val_batches['rpos']
                        X_rel_val_batch = val_batches['rel']
                        X_dis_val_batch = val_batches['dis']
                        y_val_batch = val_batches['label']

                        loss_val, length, val_summary, logits, trans_params =\
                            sess.run([
                                self.loss, 
                                self.length,
                                self.val_summary,
                                self.logits,
                                self.trans_params,
                            ], 
                            feed_dict={
                                self.inputs:X_val_batch,
                                self.lefts:X_left_val_batch,
                                self.rights:X_right_val_batch,
                                self.poses:X_pos_val_batch,
                                self.lposes:X_lpos_val_batch,
                                self.rposes:X_rpos_val_batch,
                                self.rels:X_rel_val_batch,
                                self.dises:X_dis_val_batch,
                                self.targets:y_val_batch 
                                # self.targets_weight:y_val_weight_batch
                            })

                        predicts_val = self.viterbi(logits, trans_params, length, predict_size=self.batch_size)
                        preds_lines.extend(predicts_val)
                    # Drop padding from the last partial batch before scoring.
                    preds_lines = preds_lines[:len(y_val)]
                    recall_val, precision_val, f1_val, errors = helper.calc_f1(preds_lines, id2label, 'cpbdev.txt', 'validation.out')
                    if f1_val > self.max_f1:
                        self.max_f1 = f1_val
                        save_path = saver.save(sess, save_file)
                        helper.calc_f1(preds_lines, id2label, 'cpbdev.txt', 'validation.out.best')
                        print "saved the best model with f1: %.5f" % (self.max_f1)
                    print "valid precision: %.5f, valid recall: %.5f, valid f1: %.5f, errors: %5d" % (precision_val, recall_val, f1_val, errors)
    def train(self, sess, save_file, X_train, y_train, X_val, y_val):
        saver = tf.train.Saver()

        char2id, id2char = helper.loadMap("char2id")
        label2id, id2label = helper.loadMap("label2id")

        merged = tf.merge_all_summaries()
        summary_writer_train = tf.train.SummaryWriter('loss_log/train_loss', sess.graph)  
        summary_writer_val = tf.train.SummaryWriter('loss_log/val_loss', sess.graph)     
        
        num_iterations = int(math.ceil(1.0 * len(X_train) / self.batch_size))

        cnt = 0
        for epoch in range(self.num_epochs):
            # shuffle train in each epoch
            sh_index = np.arange(len(X_train))
            np.random.shuffle(sh_index)
            X_train = X_train[sh_index]
            y_train = y_train[sh_index]
            print "current epoch: %d" % (epoch)
            for iteration in range(num_iterations):
                # train
                X_train_batch, y_train_batch = helper.nextBatch(X_train, y_train, start_index=iteration * self.batch_size, batch_size=self.batch_size)
                y_train_weight_batch = 1 + np.array((y_train_batch == label2id['B']) | (y_train_batch == label2id['E']), float)
                transition_batch = helper.getTransition(y_train_batch)
                
                _, loss_train, max_scores, max_scores_pre, length, train_summary =\
                    sess.run([
                        self.optimizer, 
                        self.loss, 
                        self.max_scores, 
                        self.max_scores_pre, 
                        self.length,
                        self.train_summary
                    ], 
                    feed_dict={
                        self.targets_transition:transition_batch, 
                        self.inputs:X_train_batch, 
                        self.targets:y_train_batch, 
                        self.targets_weight:y_train_weight_batch
                    })

                predicts_train = self.viterbi(max_scores, max_scores_pre, length, predict_size=self.batch_size)
                if iteration % 10 == 0:
                    cnt += 1
                    precision_train, recall_train, f1_train = self.evaluate(X_train_batch, y_train_batch, predicts_train, id2char, id2label)
                    summary_writer_train.add_summary(train_summary, cnt)
                    print "iteration: %5d, train loss: %5d, train precision: %.5f, train recall: %.5f, train f1: %.5f" % (iteration, loss_train, precision_train, recall_train, f1_train)  
                    
                # validation
                if iteration % 100 == 0:
                    X_val_batch, y_val_batch = helper.nextRandomBatch(X_val, y_val, batch_size=self.batch_size)
                    y_val_weight_batch = 1 + np.array((y_val_batch == label2id['B']) | (y_val_batch == label2id['E']), float)
                    transition_batch = helper.getTransition(y_val_batch)
                    
                    loss_val, max_scores, max_scores_pre, length, val_summary =\
                        sess.run([
                            self.loss, 
                            self.max_scores, 
                            self.max_scores_pre, 
                            self.length,
                            self.val_summary
                        ], 
                        feed_dict={
                            self.targets_transition:transition_batch, 
                            self.inputs:X_val_batch, 
                            self.targets:y_val_batch, 
                            self.targets_weight:y_val_weight_batch
                        })
                    
                    predicts_val = self.viterbi(max_scores, max_scores_pre, length, predict_size=self.batch_size)
                    precision_val, recall_val, f1_val = self.evaluate(X_val_batch, y_val_batch, predicts_val, id2char, id2label)
                    summary_writer_val.add_summary(val_summary, cnt)
                    print "iteration: %5d, valid loss: %5d, valid precision: %.5f, valid recall: %.5f, valid f1: %.5f" % (iteration, loss_val, precision_val, recall_val, f1_val)

                    if f1_val > self.max_f1:
                        self.max_f1 = f1_val
                        save_path = saver.save(sess, save_file)
                        print "saved the best model with f1: %.5f" % (self.max_f1)
    def train_an_iteration(self, sess, save_file, X_train, y_train, X_val, y_val, char2id, id2char, label2id, id2label, summary_writer_train_source, summary_writer_train_target, summary_writer_val_source, summary_writer_val_target, flag="source", is_summary=False, is_validation=False, X_test_source=None, y_test_source=None):
        saver = tf.train.Saver()
        
#        num_iterations = int(math.ceil(1.0 * len(X_train) / self.batch_size))
        num_iterations = 1       
        cnt = 0
        for iteration in range(num_iterations):
            print "this iteration is %d, the flag is %s"%(iteration, flag)
            # train, the flag indicate the data source
            X_train_batch, y_train_batch = helper.nextBatch(X_train, y_train, start_index=iteration * self.batch_size, batch_size=self.batch_size)
#            y_train_weight_batch = 1 + np.array((y_train_batch == label2id['B']) | (y_train_batch == label2id['E']), float)
            transition_batch = helper.getTransition(y_train_batch)
                
            _, loss_train, max_scores, max_scores_pre, length, train_summary =\
                sess.run([
                    self.optimizer_source if flag == "source" else self.optimizer_target,
                    self.loss_source if flag == "source" else self.loss_target,
                    self.max_scores_source if flag == "source" else self.max_scores_target,
                    self.max_scores_pre_source if flag == "source" else self.max_scores_pre_target,
                    self.length,
                    self.train_summary_source if flag == "source" else self.train_summary_target
                ], 
                feed_dict={
                    self.targets_transition:transition_batch, 
                    self.inputs:X_train_batch, 
                    self.targets:y_train_batch, 
 #                   self.targets_weight:y_train_weight_batch
                })

##            predicts_train = self.viterbi(max_scores, max_scores_pre, length, predict_size=self.batch_size)
##            if is_summary:
##                cnt += 1
##                presicion_loc, recall_loc, f_loc, presicion_org, recall_org, f_org, presicion_per, recall_per, f_per  = self.evaluate(X_train_batch, y_train_batch, predicts_train, id2char, id2label)
#                if flag == "source":
#                    summary_writer_train_source.add_summary(train_summary, cnt)
#                else:
#                    summary_writer_train_target.add_summary(train_summary, cnt)
                 
##                print "iteration: %5d, %s train loss: %5d, train precision: LOC %.5f, ORG %.5f, PER %.5f, train recall: LOC %.5f, ORG %.5f, PER %.5f, train f1: LOC %.5f, ORG %.5f, PER %.5f" % (iteration, flag, loss_train, presicion_loc, presicion_org, presicion_per, recall_loc, recall_org, recall_per, f_loc, f_org, f_per)
            # validation
    ##        if is_validation:
    ##            X_val_batch, y_val_batch = helper.nextRandomBatch(X_test_source, y_test_source, batch_size=self.batch_size)
#   ##             y_val_weight_batch = 1 + np.array((y_val_batch == label2id['B']) | (y_val_batch == label2id['E']), float)
    ##            transition_batch = helper.getTransition(y_val_batch)
    ##                
    ##            loss_val, max_scores, max_scores_pre, length, val_summary =\
    ##                sess.run([
    ##                    self.loss_source if flag == "source" else self.loss_target,
    ##                    self.max_scores_source if flag == "source" else self.max_scores_target,
    ##                   self.max_scores_pre_source if flag == "source" else self.max_scores_pre_target,
    ##                  self.length,
    ##                 self.train_summary_source if flag == "source" else self.train_summary_target
    ##                ],
    ##                feed_dict={
    ##                    self.targets_transition:transition_batch, 
    ##                    self.inputs:X_val_batch, 
    ##                    self.targets:y_val_batch, 
 #  ##                     self.targets_weight:y_val_weight_batch
    ##                })
     ##           
     ##           predicts_val = self.viterbi(max_scores, max_scores_pre, length, predict_size=self.batch_size)
     ##           presicion_loc, recall_loc, f_loc, presicion_org, recall_org, f_org, presicion_per, recall_per, f_per  = self.evaluate(X_val_batch, y_val_batch, predicts_val, id2char, id2label)
#               if flag == "source":
#                    summary_writer_val_source.add_summary(val_summary, cnt)
#                else:
#                   summary_writer_val_target.add_summary(val_summary, cnt)
            if is_summary:
              presicion_loc, recall_loc, f_loc, presicion_org, recall_org, f_org, presicion_per, recall_per, f_per = self.predictBatch(sess, X_test_source, y_test_source, id2label, id2char, label2id, self.batch_size, "target")
              print "iteration: %5d, %s valid , valid precision: LOC %.5f, ORG %.5f, PER %.5f, valid recall: LOC %.5f, ORG %.5f, PER %.5f, valid f1: LOC %.5f, ORG %.5f, PER %.5f" % (iteration, flag, presicion_loc, presicion_org, presicion_per, recall_loc, recall_org, recall_per, f_loc, f_org, f_per)

              if f_loc + f_org + f_per >= self.max_f1:
                self.max_f1 = f_loc + f_org + f_per
                save_path = saver.save(sess, save_file)
                print "saved the best model with f1: %.5f" % (self.max_f1 / 3.0)


#                print "********************************************************"
#                print "********************************************************"
#                X = helper.getPredict("./a", char2id)
#                saver = tf.train.Saver()
#                saver.restore(sess, model_path)
#                model.predictBatch(sess, X, y_true, id2label, id2char, label2id, 128, "source")
#                print "********************************************************"
#                print "********************************************************"
              self.last_f = f_loc + f_org + f_per
    def train(self, sess, save_file, X_train, y_train, X_val, y_val):
        """Train the two-input model, validating every 20 iterations.

        A random validation batch is scored every 20 iterations and the model
        is checkpointed to `save_file` whenever validation accuracy exceeds
        `self.max_acc`.

        Args:
            sess: open TensorFlow session with the model graph already built.
            save_file: checkpoint path prefix passed to `tf.train.Saver.save`.
            X_train, y_train: training inputs/labels (numpy arrays).
            X_val, y_val: validation inputs/labels (numpy arrays).
        """
        saver = tf.train.Saver()

        summary_writer_train = tf.summary.FileWriter('loss_log/train_loss',
                                                     sess.graph)
        summary_writer_val = tf.summary.FileWriter('loss_log/val_loss',
                                                   sess.graph)

        num_iterations = int(math.ceil(1.0 * len(X_train) / self.batch_size))

        for epoch in range(self.num_epochs):

            # shuffle train in each epoch
            sh_index = np.arange(len(X_train))
            np.random.shuffle(sh_index)
            X_train = X_train[sh_index]
            y_train = y_train[sh_index]

            print("current epoch: %d" % (epoch))

            for iteration in range(num_iterations):
                # global step for summaries; using the bare `iteration` would
                # reset each epoch and overwrite earlier TensorBoard points
                step = epoch * num_iterations + iteration

                # train
                X_train_batch1, X_train_batch2, y_train_batch = helper.nextBatch(
                    X_train, y_train, iteration * self.batch_size,
                    self.batch_size)

                _, train_loss, train_acc, train_summary = sess.run(
                    [
                        self.optimizer,
                        self.loss,
                        self.accuracy,
                        self.summary_op
                    ],
                    feed_dict={
                        self.input1: X_train_batch1,
                        self.input2: X_train_batch2,
                        self.targets: y_train_batch
                    })

                if iteration % 20 == 0:
                    summary_writer_train.add_summary(train_summary, step)
                    # %.5f: the loss is a float; %5d would silently truncate it
                    print(
                        "iteration: %5d, train loss: %.5f, train precision: %.5f"
                        % (iteration, train_loss, train_acc))

                # validation on a random batch
                if iteration % 20 == 0:

                    X_val_batch1, X_val_batch2, y_val_batch = helper.nextRandomBatch(
                        X_val, y_val, self.batch_size)
                    dev_loss, dev_acc, val_summary = sess.run(
                        [
                            self.loss,
                            self.accuracy,
                            self.summary_op
                        ],
                        feed_dict={
                            self.input1: X_val_batch1,
                            self.input2: X_val_batch2,
                            self.targets: y_val_batch
                        })

                    summary_writer_val.add_summary(val_summary, step)
                    print(
                        "iteration: %5d, dev loss: %.5f, dev precision: %.5f" %
                        (iteration, dev_loss, dev_acc))

                    # keep only the checkpoint with the best validation accuracy
                    if dev_acc > self.max_acc:
                        self.max_acc = dev_acc
                        saver.save(sess, save_file)
                        print("saved the best model with accuracy: %.5f" %
                              (self.max_acc))