Example #1
    def train(self,
              sess,
              dataset,
              generate_session=None,
              is_train=True,
              ftest_name=FLAGS['agn_output_file'].value):
        st, ed, loss, acc, acc_1 = 0, 0, [], [], []
        # Optionally mix generated sessions into the data before batching.
        if generate_session:
            dataset = dataset + generate_session
        print("Get %s data: len(dataset) is %d" %
              ("training" if is_train else "testing", len(dataset)))
        if not is_train:
            # Truncate the test-output file; compute_acc appends to it below.
            fout = open(ftest_name, "w")
            fout.close()
        while ed < len(dataset):
            # Advance the batch window, clamping the last batch to the dataset end.
            st, ed = ed, min(ed + FLAGS['batch_size'].value, len(dataset))
            batch_data = gen_batched_data(dataset[st:ed])
            # forward_only=True skips the parameter update when evaluating.
            outputs = self.step_decoder(
                sess, batch_data, forward_only=not is_train)
            loss.append(outputs[0])
            predict_id = outputs[1]  # [batch_size, length, 10]

            tmp_acc, tmp_acc_1 = compute_acc(batch_data["aims"],
                                             predict_id,
                                             batch_data["rec_lists"],
                                             batch_data["rec_mask"],
                                             batch_data["purchase"],
                                             ftest_name=ftest_name,
                                             output=(not is_train))
            acc.append(tmp_acc)
            acc_1.append(tmp_acc_1)
        if is_train:
            sess.run(self.epoch_add_op)
        return np.mean(loss), np.mean(acc), np.mean(acc_1)
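
The loop above advances (st, ed) by one batch size and clamps the final window to the end of the dataset; the same idiom recurs in the examples below. A minimal standalone sketch of just that indexing, with a hypothetical helper name that is not part of the original code:

def batch_windows(n, batch_size):
    # Yield (st, ed) pairs covering range(n) in steps of batch_size,
    # clamping the final window to n, exactly like the while-loop above.
    st, ed = 0, 0
    while ed < n:
        st, ed = ed, min(ed + batch_size, n)
        yield st, ed

# e.g. list(batch_windows(10, 4)) == [(0, 4), (4, 8), (8, 10)]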
Example #2
    def pg_train(self, sess, dataset):
        st, ed, loss = 0, 0, []
        print("Get %s data: len(dataset) is %d" % ("training", len(dataset)))
        while ed < len(dataset):
            # Advance the batch window, clamping the last batch to the dataset end.
            st, ed = ed, min(ed + FLAGS['batch_size'].value, len(dataset))
            batch_data = gen_batched_data(dataset[st:ed])
            outputs = self.pg_step_decoder(sess,
                                           batch_data,
                                           forward_only=False)
            loss.append(outputs[0])
        sess.run(self.epoch_add_op)
        return np.mean(loss)
Example #3
    def train(self, data, data_gen, sess, dis_batch_size=32):
        st, ed, loss, acc = 0, 0, [], []
        while ed < len(data):
            # Advance the batch window over the real data, clamping the last batch.
            st, ed = ed, min(ed + dis_batch_size, len(data))
            # Take the matching window from the generated data, wrapping around
            # when the window crosses the end of data_gen.
            st_gen, ed_gen = st % len(data_gen), ed % len(data_gen)
            tmp_data_gen = data_gen[
                st_gen:ed_gen] if st_gen < ed_gen else data_gen[
                    st_gen:] + data_gen[:ed_gen]

            concat_data = list(data[st:ed]) + tmp_data_gen
            batch_data = gen_batched_data(concat_data)
            # Label real sessions 1 and generated sessions 0 for the discriminator.
            batch_data["labels"] = np.array([1] * (ed - st) +
                                            [0] * len(tmp_data_gen))
            outputs = self.step_decoder(sess, batch_data)
            loss.append(outputs[0])
            acc.append(outputs[1])
        sess.run(self.epoch_add_op)
        return np.mean(loss), np.mean(acc)
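
The only non-obvious step above is the wrap-around slice of data_gen, which pairs every window of real data with a same-length window of generated data even when data_gen is shorter. A minimal sketch of that indexing in isolation; the helper name is illustrative, not from the original code:

def circular_slice(items, st, ed):
    # Map absolute window indices onto a (possibly shorter) list and slice it
    # circularly, mirroring the st_gen/ed_gen logic in the loop above.
    st_w, ed_w = st % len(items), ed % len(items)
    return items[st_w:ed_w] if st_w < ed_w else items[st_w:] + items[:ed_w]

# e.g. circular_slice([0, 1, 2, 3, 4], 3, 7) == [3, 4, 0, 1]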
Example #4
    def train(self,
              sess,
              dataset,
              is_train=True,
              ftest_name=FLAGS['env_output_file'].value):
        st, ed, loss, acc, acc_1, pr_loss, pu_loss = 0, 0, [], [], [], [], []
        tp, tn, fp, fn = [], [], [], []
        print("Get %s data:len(dataset) is %d " %
              ("training" if is_train else "testing", len(dataset)))
        if not is_train:
            fout = open(ftest_name, "w")
            fout.close()
        while ed < len(dataset):
            st, ed = ed, ed + FLAGS['batch_size'].value if ed + \
                FLAGS['batch_size'].value < len(dataset) else len(dataset)
            batch_data = gen_batched_data(dataset[st:ed])
            outputs = self.step_decoder(
                sess, batch_data, forward_only=False if is_train else True)
            loss.append(outputs[0])
            predict_index = outputs[1]  # [batch_size, length, 10]
            pr_loss.append(outputs[2])
            pu_loss.append(outputs[3])
            purchase_prob = outputs[4][:, :, 1]
            tmp_acc, tmp_acc_1 = compute_acc(batch_data["aims"],
                                             predict_index,
                                             batch_data["rec_lists"],
                                             batch_data["rec_mask"],
                                             batch_data["purchase"],
                                             ftest_name=ftest_name,
                                             output=(not is_train))
            acc.append(tmp_acc)
            acc_1.append(tmp_acc_1)

            if not FLAGS['use_simulated_data'].value:
                # Accumulate a confusion matrix for purchase prediction;
                # positions labelled -1. are padding and are skipped.
                all_num, true_pos, true_neg, false_pos, false_neg = 1e-6, 0., 0., 0., 0.
                for b_pu, b_pu_l in zip(batch_data["purchase"], purchase_prob):
                    for pu, pu_l in zip(b_pu, b_pu_l):
                        if pu != -1.:
                            all_num += 1
                            if pu == 1. and pu_l > 0.5:
                                true_pos += 1
                            if pu == 1. and pu_l <= 0.5:
                                false_neg += 1
                            if pu == 0. and pu_l > 0.5:
                                false_pos += 1
                            if pu == 0. and pu_l <= 0.5:
                                true_neg += 1
                tp.append(true_pos / all_num)
                tn.append(true_neg / all_num)
                fp.append(false_pos / all_num)
                fn.append(false_neg / all_num)
        if not FLAGS['use_simulated_data'].value:
            print("Confusion matrix for purchase prediction:")
            print("true positive:%.4f" % np.mean(tp),
                  "true negative:%.4f" % np.mean(tn))
            print("false positive:%.4f" % np.mean(fp),
                  "false negative:%.4f" % np.mean(fn))
        print("predict:p@1:%.4f%%" % (np.mean(acc_1) * 100),
              "p@%d:%.4f%%" % (FLAGS['metric'].value, np.mean(acc) * 100))

        if is_train:
            sess.run(self.epoch_add_op)
        return np.mean(loss), np.mean(pr_loss), np.mean(pu_loss), np.mean(
            acc), np.mean(acc_1)
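
The nested confusion-matrix loop above can also be written with vectorized numpy operations. A minimal sketch under the same conventions (labels are 1./0. with -1. marking padded positions, probabilities above 0.5 count as positive); the function name and flat-array layout are illustrative, not from the original code:

import numpy as np

def purchase_confusion(labels, probs, threshold=0.5):
    # Return (tp, tn, fp, fn) as fractions of the non-padded positions,
    # matching the per-batch bookkeeping in the loop above.
    labels, probs = np.asarray(labels), np.asarray(probs)
    mask = labels != -1.
    n = mask.sum() + 1e-6
    pred = probs > threshold
    tp = np.sum(mask & (labels == 1.) & pred) / n
    fn = np.sum(mask & (labels == 1.) & ~pred) / n
    fp = np.sum(mask & (labels == 0.) & pred) / n
    tn = np.sum(mask & (labels == 0.) & ~pred) / n
    return tp, tn, fp, fn

# e.g. purchase_confusion([1., 0., -1., 1.], [0.9, 0.2, 0.8, 0.3])
# is approximately (0.333, 0.333, 0.0, 0.333)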
Example #5
def train(sess, dataset, is_train=True):
    def pro_acc(acc):
        # Each `ac` holds four blocks of four entries; the hard-coded 4 below
        # assumes FLAGS.n_class == 4.
        final_acc = [[] for _ in range(FLAGS.n_class)]
        for ac in acc:
            for i in range(4):
                if np.sum(ac[i * 4:(i + 1) * 4]) == 1:
                    final_acc[i].append(ac[i * 4:(i + 1) * 4])
        for i in range(4):
            print(
                "final classification confusion matrix (%d-th category):" % i,
                np.mean(final_acc[i], 0).tolist())

    st, ed = 0, 0
    loss, loss_lm, acc = [], [], []
    while ed < len(dataset):
        # Only fetch the update op (and gradient norm) when training.
        if is_train:
            output_feed = [
                model_loss,
                gradient_norm,
                update,
            ]
        else:
            output_feed = [
                model_loss,
            ]
        # Advance the batch window, clamping the last batch to the dataset end.
        st, ed = ed, min(ed + FLAGS.batch_size, len(dataset))
        if FLAGS.data_name == "multi_roc" or FLAGS.data_name == "roc":
            batch_data = gen_batched_data(dataset[st:ed])
            input_feed = {
                context: batch_data["story"],
                context_length: batch_data["story_length"],
                label: batch_data["label"],
            }
            if FLAGS.data_name == "multi_roc":
                output_feed.append(model_loss_lm)
                output_feed.append(model_acc_list)
        elif FLAGS.data_name == "kg":
            batch_data = gen_batched_data_from_kg(dataset[st:ed])
            input_feed = {
                context: batch_data["story"],
                context_length: batch_data["story_length"],
            }
        else:
            # Fail fast on an unknown dataset name; otherwise input_feed would
            # be undefined at the sess.run call below.
            raise ValueError("unknown FLAGS.data_name: %s" % FLAGS.data_name)
        outputs = sess.run(output_feed, input_feed)
        loss.append(outputs[0])
        if FLAGS.data_name == "multi_roc":
            loss_lm.append(outputs[-2])
            acc.append(outputs[-1])
            if (st + 1) % 10000 == 0:
                print("current ppl:%.4f" % np.exp(np.mean(loss_lm)))
                pro_acc(acc)
                print("=" * 5)
    if is_train:
        sess.run(epoch_add_op)
    if FLAGS.data_name == "multi_roc":
        pro_acc(acc)
        return np.exp(np.mean(loss_lm))
    else:
        return np.exp(np.mean(loss))
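
The values returned above are perplexities: the exponential of the mean cross-entropy loss over the epoch. A small numeric illustration with hypothetical per-batch losses:

import numpy as np

losses = [4.7, 4.5, 4.6]       # hypothetical per-batch language-model losses
ppl = np.exp(np.mean(losses))  # exp(4.6) ~= 99.5, the epoch perplexity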