Example #1
def main():
    config_name = sys.argv[1]
    test_file = "./data/250w/fre_phrase.pair"
    brae_config = GBRAEConfig(config_name)
    gbrae_data_name = "data/250w/gbrae.data.250w.min.count.%d.pkl" % brae_config.min_count
    gbrae_dict_name = "data/250w/gbrae.dict.250w.min.count.%d.pkl" % brae_config.min_count
    gbrae_phrase_dict_name = "data/250w/gbrae.250w.phrase.text.dict.pkl"
    model_name = sys.argv[2]
    random_state_name = sys.argv[3]
    logger_name = sys.argv[4]
    pre_logger(logger_name)
    load_random_state(random_state_name)
    src_word_dict, tar_word_dict = load_gbrae_dict(gbrae_dict_name)
    brae = pre_model(src_word_dict, tar_word_dict, brae_config, verbose=True)
    brae.load_model(model_name)
    src_phrases, tar_phrases, src_tar_pair = load_gbrae_data(gbrae_data_name)
    with open(gbrae_phrase_dict_name, 'rb') as fin:
        src_phrase2id = pickle.load(fin)
        tar_phrase2id = pickle.load(fin)
    test_pair = load_sub_data_pair(test_file, src_phrase2id, tar_phrase2id)
    brae.test_kl(src_phrases, tar_phrases, test_pair, brae_config)
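These listings drop their module entry point; a minimal guard, with a hypothetical invocation in the comment (the script name is an assumption):

if __name__ == "__main__":
    # e.g.: python test_gbrae.py gbrae.config model/gbrae.model rng.state kl.log
    main()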
Example #2
def bi(embedding_name, word_dim, hidden_dims, batch_size, dropout, act,
       cv_num=5, iter_num=25):
    model_name = "wdim_%d_hdim_%s_batch_%d_dropout_%f_act_%s" % (word_dim, hidden_dims,
                                                                 batch_size, dropout, act)
    log_file_name = "%s.log" % model_name
    logger = pre_logger(log_file_name)
    word_idx = get_ccf_word_map("user_tag_query.2W.seg.utf8.TRAIN")
    data_x, data_y = read_ccf_file("user_tag_query.2W.seg.utf8.TRAIN", word_idx, add_unknown_word=False)
    labels_nums = [len(np.unique(data_y[:, task_index])) - 1 for task_index in xrange(data_y.shape[1])]
    max_dev_acc_list = list()
    # predict_result_list = list()
    for train_index, dev_index in make_cv_index(data_x.shape[0], cv_num):
        train_x = data_x[train_index]
        train_y = data_y[train_index]
        dev_x = data_x[dev_index]
        dev_y = data_y[dev_index]
        classifier = pre_classifier(word_idx, embedding_name, labels_nums, word_dim,
                                    hidden_dims, batch_size, dropout, act)
        max_dev_acc = classifier.train([train_x, train_y], [dev_x, dev_y], iter_num=iter_num)
        max_dev_acc_list.append(max_dev_acc)
    logger.info("Aver Dev Acc: %f" % np.mean(max_dev_acc_list))
Example #3
def main():
    train_test = sys.argv[1]
    if train_test not in ["train", "predict"]:
        sys.stderr("train or predict")
        exit(1)
    config_name = sys.argv[2]
    forced_decode_data = "data/brae.train.data"
    phrase_data_path = "data/phrase.list"
    src_count_path = "data/src.trans.data"
    tar_count_path = "data/tar.trans.data"
    brae_config = BRAEISOMAPConfig(config_name)
    train_name = "dim%d_lrec%f_lsem%f_ll2%f_alpha%f_beta%f_num%d_seed%d_batch%d_lr%f" % (brae_config.dim,
                                                                                         brae_config.weight_rec,
                                                                                         brae_config.weight_sem,
                                                                                         brae_config.weight_l2,
                                                                                         brae_config.alpha,
                                                                                         brae_config.beta,
                                                                                         brae_config.trans_num,
                                                                                         brae_config.random_seed,
                                                                                         brae_config.batch_size,
                                                                                         brae_config.optimizer.param["lr"])
    model_name = "model/%s" % train_name
    temp_model = model_name + ".temp"
    if train_test == "train":
        start_iter = int(sys.argv[3]) if len(sys.argv) > 3 else 0
        end_iter = int(sys.argv[4]) if len(sys.argv) > 4 else 25
        pre_logger("braeisomap_" + train_name)
        np.random.seed(brae_config.random_seed)
        if start_iter == 0:
            src_word_dict, tar_word_dict = read_phrase_pair_vocab(forced_decode_data)
            src_word_dict, tar_word_dict = add_trans_word_vocab(src_count_path, src_word_dict, tar_word_dict)
            tar_word_dict, src_word_dict = add_trans_word_vocab(tar_count_path, tar_word_dict, src_word_dict)
            src_word_dict = filter_vocab(src_word_dict, min_count=0)
            tar_word_dict = filter_vocab(tar_word_dict, min_count=0)
            src_phrases, tar_phrases, src_tar_pair = read_phrase_list(forced_decode_data, src_word_dict, tar_word_dict)

            src_phrases, tar_phrases = read_trans_list(src_count_path, src_phrases, tar_phrases,
                                                       src_word_dict, tar_word_dict)
            tar_phrases, src_phrases = read_trans_list(tar_count_path, tar_phrases, src_phrases,
                                                       tar_word_dict, src_word_dict)
            src_phrases = clean_text(src_phrases)
            tar_phrases = clean_text(tar_phrases)
            brae = pre_model(src_word_dict, tar_word_dict, brae_config, verbose=True)
            with open(temp_model, 'wb') as fout:
                pickle.dump(src_phrases, fout)
                pickle.dump(tar_phrases, fout)
                pickle.dump(src_tar_pair, fout)
                pickle.dump(brae, fout)
                pickle.dump(np.random.get_state(), fout)
            if end_iter == 1:
                exit(1)
        else:
            with open(temp_model, 'rb') as fin:
                src_phrases = pickle.load(fin)
                tar_phrases = pickle.load(fin)
                src_tar_pair = pickle.load(fin)
                brae = pickle.load(fin)
                np.random.set_state(pickle.load(fin))
        brae.train(src_phrases, tar_phrases, src_tar_pair, brae_config, model_name, start_iter, end_iter)
        brae.save_model("%s.model" % model_name)
    elif train_test == "predict":
        num_process = int(sys.argv[3]) if len(sys.argv) > 3 else 0
        brae_predict(phrase_data_path, train_name + ".pred", model_file="%s.model" % model_name,
                     bilinear=True, num_process=num_process)
    else:
        sys.stderr("train or predict")
        exit(1)
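The driver takes its mode and settings from sys.argv; hypothetical invocations (the script name is an assumption):

# train from scratch for 25 iterations, resume later from iteration 10:
#   python train_brae_isomap.py train brae.config 0 25
#   python train_brae_isomap.py train brae.config 10 25
# predict with 4 worker processes:
#   python train_brae_isomap.py predict brae.config 4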
Example #4
def main():
    config_name = sys.argv[1]
    forced_decode_data = "../gbrae/data/250w/tune_hyperparameter/tune.data"
    brae_config = BRAEConfig(config_name)
    train_data = "../gbrae/data/250w/tune_hyperparameter/train/tune.train"
    dev_data = "../gbrae/data/250w/tune_hyperparameter/dev/tune.dev"
    test_data = "../gbrae/data/250w/tune_hyperparameter/test/tune.test"
    train_name = "dim%d_lrec%f_lsem%f_ll2%f_alpha%f_seed%d_batch%d_min%d_lr%f" % (
        brae_config.dim,
        brae_config.weight_rec,
        brae_config.weight_sem,
        brae_config.weight_l2,
        brae_config.alpha,
        brae_config.random_seed,
        brae_config.batch_size,
        brae_config.min_count,
        brae_config.optimizer.param["lr"],
    )
    model_name = "model/%s" % train_name
    temp_model = model_name + ".temp"
    start_iter = int(sys.argv[3]) if len(sys.argv) > 3 else 0
    end_iter = int(sys.argv[4]) if len(sys.argv) > 4 else 26
    pre_logger("brae_" + train_name)
    np.random.seed(brae_config.random_seed)
    if start_iter == 0:
        print "Load Dict ..."
        en_embedding_name = "../gbrae/data/embedding/en.token.dim%d.bin" % brae_config.dim
        zh_embedding_name = "../gbrae/data/embedding/zh.token.dim%d.bin" % brae_config.dim
        tar_word_dict = WordEmbedding.load_word2vec_word_map(en_embedding_name,
                                                             binary=True,
                                                             oov=True)
        src_word_dict = WordEmbedding.load_word2vec_word_map(zh_embedding_name,
                                                             binary=True,
                                                             oov=True)
        print "Compiling Model ..."
        brae = pre_model(src_word_dict,
                         tar_word_dict,
                         brae_config,
                         verbose=True)
        print "Load All Data ..."
        src_phrases, tar_phrases, src_tar_pair = read_phrase_list(
            forced_decode_data, src_word_dict, tar_word_dict)
        src_train = [p[WORD_INDEX] for p in src_phrases]
        tar_train = [p[WORD_INDEX] for p in tar_phrases]
        print "Write Binary Data ..."
        with open(temp_model, 'wb') as fout:
            pickle.dump(src_train, fout)
            pickle.dump(tar_train, fout)
            # the phrase lists are needed after a resume to rebuild the
            # phrase2id maps below
            pickle.dump(src_phrases, fout)
            pickle.dump(tar_phrases, fout)
            pickle.dump(src_tar_pair, fout)
            pickle.dump(brae, fout)
            pickle.dump(np.random.get_state(), fout)
        if end_iter == 1:
            exit(1)
    else:
        with open(temp_model, 'rb') as fin:
            src_train = pickle.load(fin)
            tar_train = pickle.load(fin)
            src_phrases = pickle.load(fin)
            tar_phrases = pickle.load(fin)
            src_tar_pair = pickle.load(fin)
            brae = pickle.load(fin)
            np.random.set_state(pickle.load(fin))
    src_phrase2id = dict()
    tar_phrase2id = dict()
    for i, phrase in enumerate(src_phrases):
        src_phrase2id[phrase[TEXT_INDEX]] = i
    for i, phrase in enumerate(tar_phrases):
        tar_phrase2id[phrase[TEXT_INDEX]] = i
    train_pair = load_sub_data_pair(train_data, src_phrase2id, tar_phrase2id)
    dev_pair = load_sub_data_pair(dev_data, src_phrase2id, tar_phrase2id)
    test_pair = load_sub_data_pair(test_data, src_phrase2id, tar_phrase2id)
    brae.tune_hyper_parameter(src_train,
                              tar_train,
                              train_pair,
                              dev_pair,
                              test_pair,
                              brae_config,
                              model_name,
                              start_iter=start_iter,
                              end_iter=end_iter)
    brae.save_model("%s.tune.model" % model_name)
Example #5
                    default=False,
                    help='Bilinear Score')
parser.add_argument('-p',
                    '--process',
                    type=int,
                    default=1,
                    help='Multi-Process, default is 1')
parser.add_argument('-n',
                    '--normalize',
                    type=bool,
                    default=True,
                    help='Normalize')
parser.add_argument('--neg', type=str, help='Neg Phrase File Name')
parser.add_argument('--neg_out', type=str, help='Neg Out File Name')
args = parser.parse_args()
pre_logger("score_margin_exp_" + args.model.split(os.sep)[-1])
logger = logging.getLogger(__name__)

brae_predict(phrase_file=args.file,
             output_file=args.output,
             model_file=args.model,
             normalize=args.normalize,
             num_process=args.process,
             bilinear=args.bilinear)

if args.neg is not None:
    if args.neg_out is None:
        raise IOError("--neg_out is required when --neg is given")
    brae_predict(phrase_file=args.neg,
                 output_file=args.neg_out,
                 model_file=args.model,
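One caveat in the parser above: argparse applies type=bool to the raw argument string, and bool("False") is True because any non-empty string is truthy, so "--normalize False" cannot disable normalization. A common workaround is an explicit converter (str2bool is a hypothetical helper, not part of the original script):

def str2bool(value):
    # map common spellings explicitly; argparse passes the raw string
    return value.lower() in ("1", "true", "yes", "y")

# parser.add_argument('-n', '--normalize', type=str2bool, default=True,
#                     help='Normalize')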
Example #6

if __name__ == "__main__":
    train_file = "train.txt"
    dev_file = "dev.txt"
    test_file = "test.txt"

    parser = argparse.ArgumentParser()
    parser.add_argument('--entity', type=int, default=50, help='Entity Dim')
    parser.add_argument('--seed', type=int, default=1993, help='Random Seed')
    parser.add_argument('--batch', type=int, default=50, help='Batch Size')
    parser.add_argument('--iter', type=int, default=500, help='Max Iter Number')
    parser.add_argument('--hidden', type=int, default=50, help='Hidden Dim')
    parser.add_argument('--scale', type=float, default=0.1, help='Uniform Initializer Scale')
    parser.add_argument('--negative', type=int, default=10, help='Negative Number')
    # parser.add_argument('--lr', type=float, default=0.01, help='Learning Rate')
    args = parser.parse_args()
    log_file_name = "slm_e%s_h%s_b%s_n%d_seed%d_i" \
                    "%s_scale%s" % (args.entity, args.hidden, args.batch, args.negative, args.seed,
                                                                args.iter, args.scale)

    pre_logger(log_file_name)
    np.random.seed(args.seed)
    entity_dict, relation_dict = data_to_indexs([train_file, dev_file, test_file])
    train_data = read_train_data_file(train_file, entity_dict, relation_dict)
    dev_data = read_eval_data_file(dev_file, entity_dict, relation_dict)
    test_data = read_eval_data_file(test_file, entity_dict, relation_dict)
    trainer = ReasonTrainer(entity_dict, relation_dict, entity_dim=args.entity, k=args.hidden,
                            initializer=UniformInitializer(scale=args.scale))
    trainer.train_relation(train_data, dev_data, test_data, max_iter=args.iter, C=args.negative, batch_size=args.batch)
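UniformInitializer is not defined in this snippet; a plausible minimal stand-in consistent with the scale argument used above (the generate interface is an assumption):

import numpy as np

class UniformInitializer(object):
    def __init__(self, scale=0.1):
        self.scale = scale

    def generate(self, shape):
        # sample weights uniformly from [-scale, scale]
        return np.random.uniform(-self.scale, self.scale, size=shape)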
Example #7
def main():
    config_name = sys.argv[1]
    train_data = "data/250w/tune_hyperparameter/train/tune.train"
    dev_data = "data/250w/tune_hyperparameter/dev/tune.dev"
    test_data = "data/250w/tune_hyperparameter/test/tune.test"
    brae_config = GBRAEConfig(config_name)
    gbrae_data_name = "data/250w/tune_hyperparameter/gbrae.data.tune.min.count.%d.pkl" % brae_config.min_count
    gbrae_dict_name = "data/250w/tune_hyperparameter/gbrae.dict.tune.min.count.%d.pkl" % brae_config.min_count
    gbrae_phrase_dict_name = "data/250w/tune_hyperparameter/gbrae.tune.phrase.text.dict.pkl"
    train_name = "dim%d_lrec%f_lsem%f_lword%f_alpha%f_beta%f_gama%f_num%d_seed%d_batch%d_min%d_lr%f" % (
        brae_config.dim, brae_config.weight_rec, brae_config.weight_sem,
        brae_config.weight_l2, brae_config.alpha, brae_config.beta,
        brae_config.gama, brae_config.trans_num, brae_config.random_seed,
        brae_config.batch_size, brae_config.min_count,
        brae_config.optimizer.param["lr"])
    model_name = "model/%s" % "gbrae_tune_hyper_" + train_name
    pre_train_model_file_name = None
    temp_model = model_name + ".temp"
    start_iter = int(sys.argv[2]) if len(sys.argv) > 2 else 0
    end_iter = int(sys.argv[3]) if len(sys.argv) > 3 else 26
    if len(sys.argv) > 5:
        pre_train_model_file_name = sys.argv[5]
        model_name += "_pred_%s" % pre_train_model_file_name
    pre_logger("gbrae_tune_hyper_" + train_name)
    np.random.seed(brae_config.random_seed)
    if start_iter == 0:
        print "Load Dict ..."
        src_word_dict, tar_word_dict = load_gbrae_dict(gbrae_dict_name)
        print "Compiling Model ..."
        brae = pre_model(src_word_dict,
                         tar_word_dict,
                         brae_config,
                         verbose=True)
        if pre_train_model_file_name is not None:
            brae.load_model(pre_train_model_file_name)
        print "Write Binary Data ..."
        with open(temp_model, 'wb') as fout:
            pickle.dump(brae, fout)
            pickle.dump(np.random.get_state(), fout)
        if end_iter == 1:
            exit(1)
    else:
        with open(temp_model, 'rb') as fin:
            brae = pickle.load(fin)
            np.random.set_state(pickle.load(fin))
    src_phrases, tar_phrases, src_tar_pair = load_gbrae_data(gbrae_data_name)
    if has_zero_para(src_phrases):
        print "src has zero para"
    else:
        print "src has no zero para"
    if has_zero_para(tar_phrases):
        print "tar has zero para"
    else:
        print "tar has no zero para"
    with open(gbrae_phrase_dict_name, 'rb') as fin:
        src_phrase2id = pickle.load(fin)
        tar_phrase2id = pickle.load(fin)
    train_pair = load_sub_data_pair(train_data, src_phrase2id, tar_phrase2id)
    dev_pair = load_sub_data_pair(dev_data, src_phrase2id, tar_phrase2id)
    test_pair = load_sub_data_pair(test_data, src_phrase2id, tar_phrase2id)
    brae.tune_hyper_parameter(src_phrases,
                              tar_phrases,
                              train_pair,
                              dev_pair,
                              test_pair,
                              brae_config,
                              model_name,
                              start_iter=start_iter,
                              end_iter=end_iter)
    brae.save_model("%s.tune.model" % model_name)
Example #8
def main(_):
    phrase_file = FLAGS.phrase_file
    src_para_file = FLAGS.src_para
    tar_para_file = FLAGS.tar_para
    trans_file = FLAGS.trans_file
    src_phrase_list, tar_phrase_list, bi_phrase_list, src_word_idx, tar_word_idx = prepare_data(
        phrase_file, src_para_file, tar_para_file, trans_file)
    ssbrae_config = SSBRAEConfig(FLAGS.config_name)
    src_word_embedding = WordEmbedding(src_word_idx,
                                       dim=50,
                                       name="src_word_embedding")
    tar_word_embedding = WordEmbedding(tar_word_idx,
                                       dim=50,
                                       name="tar_word_embedding")
    sess = tf.Session()

    ssbrae_encoder = SSBRAEEncoder(
        src_word_embedding, tar_word_embedding, ssbrae_config.activation,
        ssbrae_config.normalize, ssbrae_config.weight_rec,
        ssbrae_config.weight_sem, ssbrae_config.weight_embedding,
        ssbrae_config.alpha, ssbrae_config.beta, ssbrae_config.max_src_len,
        ssbrae_config.max_tar_len, ssbrae_config.n_epoch,
        ssbrae_config.batch_size, ssbrae_config.dropout,
        ssbrae_config.optimizer_config, ssbrae_config.para,
        ssbrae_config.trans, ssbrae_config.para_num, ssbrae_config.trans_num,
        sess)

    train_phrase_list = bi_phrase_list[:-2 * ssbrae_encoder.batch_size]
    valid_phrase_list = bi_phrase_list[-2 * ssbrae_encoder.batch_size:
                                       -ssbrae_encoder.batch_size]
    test_phrase_list = bi_phrase_list[-ssbrae_encoder.batch_size:]

    pre_logger("ssbrae")
    logger.info("Now train ssbrae encoder\n")
    for i in range(ssbrae_encoder.n_epoch):
        logger.info("Now train ssbrae encoder epoch %d\n" % i)
        start_time = time.time()
        losses = []

        train_phrase_index = get_train_sequence(train_phrase_list,
                                                ssbrae_encoder.batch_size)
        num_batches = int(len(train_phrase_index) / ssbrae_encoder.batch_size)
        for j in range(num_batches):
            (src_pos, tar_pos, src_neg, tar_neg, src_para, tar_para, src_para_weight, tar_para_weight, src_tar_trans,\
             tar_src_trans, src_tar_trans_weight, tar_src_trans_weight) = ssbrae_encoder.get_batch(src_phrase_list,
                                                                                                  tar_phrase_list,
                                                                                                  train_phrase_list,
                                                                                                  train_phrase_index,
                                                                                                  src_word_idx,
                                                                                                  tar_word_idx, j)
            result = ssbrae_encoder.ssbrae_train_step(
                src_pos, tar_pos, src_neg, tar_neg, src_para, tar_para,
                src_para_weight, tar_para_weight, src_tar_trans, tar_src_trans,
                src_tar_trans_weight, tar_src_trans_weight)
            if ssbrae_encoder.para and ssbrae_encoder.trans:
                logger.info(
                    "train ssbrae_para epoch %d, step %d, total loss:%f, loss_l2: %f, loss_rec: %f,"
                    "loss_sem:%f, loss_para:%f, loss_trans:%f\n" %
                    (i, j, result[1], result[2], result[3], result[4],
                     result[5], result[6]))
            elif ssbrae_encoder.para and not ssbrae_encoder.trans:
                logger.info(
                    "train ssbrae_para epoch %d, step %d, total loss:%f, loss_l2: %f, loss_rec: %f,"
                    "loss_sem:%f, loss_para:%f\n" %
                    (i, j, result[1], result[2], result[3], result[4],
                     result[5]))
            elif ssbrae_encoder.trans and not ssbrae_encoder.para:
                logger.info(
                    "train ssbrae_para epoch %d, step %d, total loss:%f, loss_l2: %f, loss_rec: %f,"
                    "loss_sem:%f, loss_trans:%f\n" %
                    (i, j, result[1], result[2], result[3], result[4],
                     result[5]))
            else:
                raise ValueError("No such configuration")
            losses.append(result[1:])

        use_time = time.time() - start_time

        valid_phrase_index = get_train_sequence(valid_phrase_list,
                                                ssbrae_encoder.batch_size)
        num_batches = int(len(valid_phrase_index) / ssbrae_encoder.batch_size)
        dev_loss = []
        for j in range(num_batches):
            (src_pos, tar_pos, src_neg, tar_neg, src_para, tar_para, src_para_weight, tar_para_weight, src_tar_trans, \
             tar_src_trans, src_tar_trans_weight, tar_src_trans_weight) = ssbrae_encoder.get_batch(src_phrase_list,
                                                                                                   tar_phrase_list,
                                                                                                   valid_phrase_list,
                                                                                                   valid_phrase_index,
                                                                                                   src_word_idx,
                                                                                                   tar_word_idx, j)
            dev_loss.append(
                ssbrae_encoder.ssbrae_predict_step(
                    src_pos, tar_pos, src_neg, tar_neg, src_para, tar_para,
                    src_para_weight, tar_para_weight, src_tar_trans,
                    tar_src_trans, src_tar_trans_weight, tar_src_trans_weight))
        logger.info("train ssbrae encoder epoch %d, use time:%d\n" %
                    (i, use_time))
        ave_train_loss = np.average(losses, axis=0)
        ave_dev_loss = np.average(dev_loss, axis=0)
        if ssbrae_encoder.para and ssbrae_encoder.trans:
            logger.info(
                "train: total loss:%f, l2 loss:%f, rec loss:%f, sem loss:%f, para loss:%f, trans loss:%f\n"
                % (ave_train_loss[0], ave_train_loss[1], ave_train_loss[2],
                   ave_train_loss[3], ave_train_loss[4], ave_train_loss[5]))
            logger.info(
                "dev: total loss:%f, l2 loss:%f, rec loss:%f, sem loss:%f, para loss:%f, trans loss:%f"
                % (ave_dev_loss[0], ave_dev_loss[1], ave_dev_loss[2],
                   ave_dev_loss[3], ave_dev_loss[4], ave_dev_loss[5]))
        elif ssbrae_encoder.para and not ssbrae_encoder.trans:
            logger.info(
                "train: total loss:%f, l2 loss:%f, rec loss:%f, sem loss:%f, para loss:%f\n"
                % (ave_train_loss[0], ave_train_loss[1], ave_train_loss[2],
                   ave_train_loss[3], ave_train_loss[4]))
            logger.info(
                "dev: total loss:%f, l2 loss:%f, rec loss:%f, sem loss:%f, para loss:%f"
                % (ave_dev_loss[0], ave_dev_loss[1], ave_dev_loss[2],
                   ave_dev_loss[3], ave_dev_loss[4]))
        elif ssbrae_encoder.trans and not ssbrae_encoder.para:
            logger.info(
                "train: total loss:%f, l2 loss:%f, rec loss:%f, sem loss:%f, trans loss:%f\n"
                % (ave_train_loss[0], ave_train_loss[1], ave_train_loss[2],
                   ave_train_loss[3], ave_train_loss[4]))
            logger.info(
                "dev: total loss:%f, l2 loss:%f, rec loss:%f, sem loss:%f, trans loss:%f"
                % (ave_dev_loss[0], ave_dev_loss[1], ave_dev_loss[2],
                   ave_dev_loss[3], ave_dev_loss[4]))

        checkpoint_path = os.path.join(FLAGS.train_dir,
                                       "ssbare_encoder.epoch%d.ckpt" % i)
        #ssbrae_encoder.saver.save(ssbrae_encoder.sess, checkpoint_path, global_step=ssbrae_encoder.global_step)
        ssbrae_encoder.saver.save(ssbrae_encoder.sess, checkpoint_path)

    test_phrase_index = get_train_sequence(test_phrase_list,
                                           ssbrae_encoder.batch_size)
    num_batches = int(len(test_phrase_index) / ssbrae_encoder.batch_size)
    test_loss = []
    for j in range(num_batches):
        (src_pos, tar_pos, src_neg, tar_neg, src_para, tar_para, src_para_weight, tar_para_weight, src_tar_trans, \
         tar_src_trans, src_tar_trans_weight, tar_src_trans_weight) = ssbrae_encoder.get_batch(src_phrase_list,
                                                                                               tar_phrase_list,
                                                                                               test_phrase_list,
                                                                                               test_phrase_index,
                                                                                               src_word_idx,
                                                                                               tar_word_idx, j)
        test_loss.append(
            ssbrae_encoder.ssbrae_predict_step(
                src_pos, tar_pos, src_neg, tar_neg, src_para, tar_para,
                src_para_weight, tar_para_weight, src_tar_trans, tar_src_trans,
                src_tar_trans_weight, tar_src_trans_weight))

    ave_test_loss = np.average(test_loss, axis=0)
    if ssbrae_encoder.para and ssbrae_encoder.trans:
        logger.info(
            "test: total loss:%f, l2 loss:%f, rec loss:%f, sem loss:%f, para loss:%f, trans loss:%f"
            % (ave_test_loss[0], ave_test_loss[1], ave_test_loss[2],
               ave_test_loss[3], ave_test_loss[4], ave_test_loss[5]))
    elif ssbrae_encoder.para and not ssbrae_encoder.trans:
        logger.info(
            "test: total loss:%f, l2 loss:%f, rec loss:%f, sem loss:%f, para loss:%f"
            % (ave_test_loss[0], ave_test_loss[1], ave_test_loss[2],
               ave_test_loss[3], ave_test_loss[4]))
    elif ssbrae_encoder.trans and not ssbrae_encoder.para:
        logger.info(
            "test: total loss:%f, l2 loss:%f, rec loss:%f, sem loss:%f, trans loss:%f"
            % (ave_test_loss[0], ave_test_loss[1], ave_test_loss[2],
               ave_test_loss[3], ave_test_loss[4]))
Example #9
def main():
    train_test = sys.argv[1]
    if train_test not in ["train", "predict"]:
        sys.stderr("train or predict")
        exit(1)
    config_name = sys.argv[2]
    forced_decode_data = "../gbrae/data/250w/phrase-table.filtered"
    phrase_data_path = "data/phrase.list"
    brae_config = BRAEConfig(config_name)
    train_name = "dim%d_lrec%f_lsem%f_ll2%f_alpha%f_seed%d_batch%d_min%d_lr%f" % (brae_config.dim,
                                                                                  brae_config.weight_rec,
                                                                                  brae_config.weight_sem,
                                                                                  brae_config.weight_l2,
                                                                                  brae_config.alpha,
                                                                                  brae_config.random_seed,
                                                                                  brae_config.batch_size,
                                                                                  brae_config.min_count,
                                                                                  brae_config.optimizer.param["lr"],)
    model_name = "model/%s" % train_name
    temp_model = model_name + ".temp"
    if train_test == "train":
        start_iter = int(sys.argv[3]) if len(sys.argv) > 3 else 0
        end_iter = int(sys.argv[4]) if len(sys.argv) > 4 else 26
        pre_logger("brae_" + train_name)
        np.random.seed(brae_config.random_seed)
        if start_iter == 0:
            print "Load Dict ..."
            en_embedding_name = "../gbrae/data/embedding/en.token.dim%d.bin" % brae_config.dim
            zh_embedding_name = "../gbrae/data/embedding/zh.token.dim%d.bin" % brae_config.dim
            tar_word_dict = WordEmbedding.load_word2vec_word_map(en_embedding_name, binary=True, oov=True)
            src_word_dict = WordEmbedding.load_word2vec_word_map(zh_embedding_name, binary=True, oov=True)
            print "Compiling Model ..."
            brae = pre_model(src_word_dict, tar_word_dict, brae_config, verbose=True)
            print "Load All Data ..."
            src_phrases, tar_phrases, src_tar_pair = read_phrase_list(forced_decode_data, src_word_dict, tar_word_dict)
            src_train = [p[WORD_INDEX] for p in src_phrases]
            tar_train = [p[WORD_INDEX] for p in tar_phrases]
            print "Write Binary Data ..."
            with open(temp_model, 'wb') as fout:
                pickle.dump(src_train, fout)
                pickle.dump(tar_train, fout)
                pickle.dump(src_tar_pair, fout)
                pickle.dump(brae, fout)
                pickle.dump(np.random.get_state(), fout)
            if end_iter == 1:
                exit(1)
        else:
            with open(temp_model, 'rb') as fin:
                src_train = pickle.load(fin)
                tar_train = pickle.load(fin)
                src_tar_pair = pickle.load(fin)
                brae = pickle.load(fin)
                np.random.set_state(pickle.load(fin))
        brae.train(src_train, tar_train, src_tar_pair, brae_config, model_name, start_iter, end_iter)
        brae.save_model("%s.model" % model_name)
    elif train_test == "predict":
        num_process = int(sys.argv[3]) if len(sys.argv) > 3 else 0
        brae_predict(phrase_data_path, train_name + ".pred", model_file="%s.model" % model_name, num_process=num_process)
    else:
        sys.stderr("train or predict")
        exit(1)
Example #10
def train(config_file, pre_train_word_count_file, emotion_words_dir, post_file, response_file, emotion_label_file,
          embedding_file, train_word_count, session, checkpoint_dir, max_vocab_size, test_post_file, test_label_file):
    """
    train the dialogue model
    :param config_file:
    :param pre_train_word_count_file:
    :param emotion_words_dir:
    :param post_file:
    :param response_file:
    :param emotion_label_file:
    :param embedding_file:
    :param train_word_count:
    :param session:
    :param checkpoint_dir:
    :param max_vocab_size:
    :param test_post_file:
    :param test_label_file:
    :return:
    """
    log_name = "dialogue"
    logger = pre_logger(log_name)

    chat_config = ChatConfig(config_file)

    logger.info("Now prepare data!\n")
    logger.info("Read stop words!\n")
    # stop_words = read_stop_words(FLAGS.stop_words_file)

    logger.info("Construct vocab first\n")
    total_embeddings, total_word2id, total_word_list = read_total_embeddings(embedding_file, max_vocab_size)
    pre_word_count = get_word_count(pre_train_word_count_file, chat_config.word_count)
    emotion_words_dict = read_emotion_words(emotion_words_dir, pre_word_count)
    word_list = construct_vocab(total_word_list, emotion_words_dict, chat_config.generic_word_size,
                                chat_config.emotion_vocab_size, FLAGS.unk)
    word_dict = construct_word_dict(word_list, FLAGS.unk, FLAGS.start_symbol, FLAGS.end_symbol)
    id2words = {idx: word for word, idx in word_dict.items()}
    word_unk_id = word_dict[FLAGS.unk]
    word_start_id = word_dict[FLAGS.start_symbol]
    word_end_id = word_dict[FLAGS.end_symbol]
    final_word_list = get_word_list(id2words)

    logger.info("Read word embeddings!\n")
    embeddings = read_word_embeddings(total_embeddings, total_word2id, final_word_list, chat_config.embedding_size)

    logger.info("Read training data!\n")
    train_post_data = read_training_file(post_file, word_dict, FLAGS.unk)
    train_response_data = read_training_file(response_file, word_dict, FLAGS.unk)
    emotion_labels = read_emotion_label(emotion_label_file)

    logger.info("Filter training data according to length!\n")
    train_post_data, train_response_data, emotion_labels = filter_sentence_length(train_post_data, train_response_data,
                                                                                  emotion_labels, chat_config.min_len,
                                                                                  chat_config.max_len)
    logger.info("Number of length <= 10 sentences: %d\n" % len(train_post_data))
    train_post_length = [len(post_data) for post_data in train_post_data]

    logger.info("Align sentence length by padding!\n")
    train_post_data = align_sentence_length(train_post_data, chat_config.max_len, word_unk_id)
    train_response_data, predict_response_data = get_predict_train_response_data(train_response_data, word_start_id,
                                                                                 word_end_id, word_unk_id,
                                                                                 chat_config.max_len)

    train_post_data, train_post_length, train_response_data, predict_response_data, emotion_labels = \
        align_batch_size(train_post_data, train_post_length, train_response_data, predict_response_data, emotion_labels,
                         chat_config.batch_size)
    logger.info("Finish preparing data!\n")

    logger.info("Read test data\n")
    test_post_data = read_training_file(test_post_file, word_dict, FLAGS.unk)
    test_label_data = read_emotion_label(test_label_file)
    test_length = len(test_post_data)

    logger.info("filter test post data length!\n")
    test_post_data, test_label_data = filter_test_sentence_length(test_post_data, test_label_data, chat_config.min_len,
                                                                  chat_config.max_len)

    logger.info("Number of length <= 10 sentences: %d\n" % len(test_post_data))
    test_post_data_length = [len(post_data) for post_data in test_post_data]

    logger.info("Align sentence length by padding!\n")
    test_post_data = align_sentence_length(test_post_data, chat_config.max_len, word_unk_id)
    test_post_data, test_post_data_length, test_label_data = \
        align_test_batch_size(test_post_data, test_post_data_length, test_label_data, chat_config.batch_size)

    logger.info("Define model\n")
    emotion_chat_machine = EmotionChatMachine(config_file, session, word_dict, embeddings,
                                              chat_config.generic_word_size + 3, word_start_id, word_end_id,
                                              "emotion_chat_machine")
    checkpoint_path = os.path.join(checkpoint_dir, "dialogue-model")

    num_train_batch = int(len(train_post_data) / chat_config.batch_size)
    train_epochs = chat_config.epochs_to_train
    
    logger.info("Start training\n")
    for i in range(train_epochs):
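        # apply the learning-rate decay op every third epoch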
        if i != 0 and i % 3 == 0:
            session.run(emotion_chat_machine.lr_decay_op)

        logger.info("Training epoch %d\n" % (i + 1))
        train_post_data, train_post_length, train_response_data, predict_response_data, emotion_labels = \
            shuffle_train_data(train_post_data, train_post_length, train_response_data, predict_response_data,
                               emotion_labels)

        for j in range(num_train_batch):
            this_post_data, this_post_len, this_train_res_data, this_predict_res_data, this_emotion_labels, \
             this_emotion_mask = emotion_chat_machine.get_batch(train_post_data, train_post_length, train_response_data,
                                                                predict_response_data, emotion_labels, j)
            loss = emotion_chat_machine.train_step(this_post_data, this_post_len, this_train_res_data,
                                                   this_predict_res_data, this_emotion_labels, this_emotion_mask)
            entropy_loss, reg_loss, total_loss = loss
            logger.info("Epoch=%d, batch=%d, total loss=%f, entropy loss=%f, reg_loss=%f\n" %
                        ((i + 1), (j + 1), total_loss, entropy_loss, reg_loss))

        logger.info("Saving parameters\n")
        emotion_chat_machine.saver.save(emotion_chat_machine.session, checkpoint_path,
                                        global_step=(i * num_train_batch))
Example #11
def generate_response(config_file, pre_train_word_count_file,
                      emotion_words_dir, embedding_file, session,
                      checkpoint_dir, max_vocab_size, test_post_file,
                      test_label_file, log_name, restore_model):
    """
    generate response from checkpoint
    :param config_file:
    :param pre_train_word_count_file:
    :param emotion_words_dir:
    :param embedding_file:
    :param session:
    :param checkpoint_dir:
    :param max_vocab_size:
    :param test_post_file:
    :param test_label_file:
    :param log_name:
    :param restore_model:
    :return:
    """
    logger = pre_logger(log_name)

    chat_config = ChatConfig(config_file)

    logger.info("Now prepare data!\n")
    logger.info("Read stop words!\n")
    stop_words = read_stop_words(FLAGS.stop_words_file)

    logger.info("Construct vocab first\n")
    total_embeddings, total_word2id, total_word_list = read_total_embeddings(
        embedding_file, max_vocab_size)
    pre_word_count = get_word_count(pre_train_word_count_file,
                                    chat_config.word_count)
    emotion_words_dict = read_emotion_words(emotion_words_dir, pre_word_count)
    word_list = construct_vocab(total_word_list, emotion_words_dict,
                                chat_config.generic_word_size,
                                chat_config.emotion_vocab_size, FLAGS.unk)
    word_dict = construct_word_dict(word_list, FLAGS.unk, FLAGS.start_symbol,
                                    FLAGS.end_symbol)
    id2words = {idx: word for word, idx in word_dict.items()}
    word_unk_id = word_dict[FLAGS.unk]
    word_start_id = word_dict[FLAGS.start_symbol]
    word_end_id = word_dict[FLAGS.end_symbol]
    final_word_list = get_word_list(id2words)

    logger.info("Read word embeddings!\n")
    embeddings = read_word_embeddings(total_embeddings, total_word2id,
                                      final_word_list,
                                      chat_config.embedding_size)

    logger.info("Read test data\n")
    test_post_data = read_training_file(test_post_file, word_dict, FLAGS.unk)
    test_label_data = read_emotion_label(test_label_file)

    logger.info("filter test post data length!\n")
    test_post_data, test_label_data = filter_test_sentence_length(
        test_post_data, test_label_data, chat_config.min_len,
        chat_config.max_len)

    logger.info("Number of length <= 10 sentences: %d\n" % len(test_post_data))
    test_post_data_length = [len(post_data) for post_data in test_post_data]
    test_length = len(test_post_data)

    logger.info("Align sentence length by padding!\n")
    test_post_data = align_sentence_length(test_post_data, chat_config.max_len,
                                           word_unk_id)
    test_post_data, test_post_data_length, test_label_data = \
        align_test_batch_size(test_post_data, test_post_data_length, test_label_data, chat_config.batch_size)

    logger.info("Define model\n")
    emotion_chat_machine = EmotionChatMachine(
        config_file, session, word_dict, embeddings,
        chat_config.generic_word_size + 3, word_start_id, word_end_id,
        "emotion_chat_machine")
    checkpoint_path = os.path.join(checkpoint_dir, restore_model)
    emotion_chat_machine.saver.restore(session, checkpoint_path)

    logger.info("Generate test data!\n")
    test_batch = int(len(test_post_data) / chat_config.batch_size)
    generate_data = []
    for k in range(test_batch):
        this_post_data, this_post_len, this_emotion_labels, this_emotion_mask = \
            emotion_chat_machine.get_test_batch(test_post_data, test_post_data_length, test_label_data, k)
        generate_words, scores, new_embeddings = emotion_chat_machine.generate_step(
            this_post_data, this_post_len, this_emotion_labels,
            this_emotion_mask)
        # generate words: [batch, beam, max len]  &&  scores: [batch, beam]
        best_generate_sen = select_best_response(
            generate_words, scores, this_post_data, this_emotion_labels,
            emotion_words_dict, chat_config.batch_size, stop_words,
            word_start_id, word_end_id)
        generate_data.extend(best_generate_sen)
    generate_data = generate_data[:test_length]
    test_label_data = test_label_data[:test_length]
    write_test_data(generate_data, FLAGS.generate_response_file, id2words,
                    test_label_data)
Example #12
def main():
    train_test = sys.argv[1]
    if train_test not in ["train", "predict"]:
        sys.stderr("train or predict")
        exit(1)
    config_name = sys.argv[2]
    phrase_data_path = "data/phrase.list"
    brae_config = GBRAEConfig(config_name)
    gbrae_data_name = "data/250w/gbrae.data.250w.min.count.%d.pkl" % brae_config.min_count
    gbrae_dict_name = "data/250w/gbrae.dict.250w.min.count.%d.pkl" % brae_config.min_count
    #gbrae_data_name = "data/250w/tune_hyperparameter/train/gbrae.data.5w.tune.min.count.%d.pkl" % brae_config.min_count
    #gbrae_dict_name = "data/250w/tune_hyperparameter/train/gbrae.dict.5w.tune.min.count.%d.pkl" % brae_config.min_count
    if brae_config.para and brae_config.trans:
        name_prefix = ""
    elif brae_config.para:
        name_prefix = "para_"
    elif brae_config.trans:
        name_prefix = "trans_"
    else:
        sys.stderr.write("config must enable para, trans, or both\n")
        exit(1)
    train_name = name_prefix + "dim%d_lrec%f_lsem%f_lword%f_alpha%f_beta%f_gama%f_num%d_seed%d_batch%d_min%d_lr%f" % (
        brae_config.dim, brae_config.weight_rec, brae_config.weight_sem,
        brae_config.weight_l2, brae_config.alpha, brae_config.beta,
        brae_config.gama, brae_config.trans_num, brae_config.random_seed,
        brae_config.batch_size, brae_config.min_count,
        brae_config.optimizer.param["lr"])

    model_name = "model/%s" % train_name
    pre_train_model_file_name = None
    temp_model = model_name + ".temp"
    if train_test == "train":
        start_iter = int(sys.argv[3]) if len(sys.argv) > 3 else 0
        end_iter = int(sys.argv[4]) if len(sys.argv) > 4 else 26
        if len(sys.argv) > 5:
            pre_train_model_file_name = sys.argv[5]
            model_name += "_pred_%s" % pre_train_model_file_name
        pre_logger("gbrae_" + train_name)
        np.random.seed(brae_config.random_seed)
        if start_iter == 0:
            print "Load Dict ..."
            src_word_dict, tar_word_dict = load_gbrae_dict(gbrae_dict_name)
            print "Compiling Model ..."
            brae = pre_model(src_word_dict,
                             tar_word_dict,
                             brae_config,
                             verbose=True)
            if pre_train_model_file_name is not None:
                brae.load_model(pre_train_model_file_name)
            print "Write Binary Data ..."
            with open(temp_model, 'wb') as fout:
                pickle.dump(brae, fout)
                pickle.dump(np.random.get_state(), fout)
            if end_iter == 1:
                exit(1)
        else:
            with open(temp_model, 'rb') as fin:
                brae = pickle.load(fin)
                np.random.set_state(pickle.load(fin))
        src_phrases, tar_phrases, src_tar_pair = load_gbrae_data(
            gbrae_data_name)
        brae.train(src_phrases, tar_phrases, src_tar_pair, brae_config,
                   model_name, start_iter, end_iter)
        brae.save_model("%s.model" % model_name)
    elif train_test == "predict":
        num_process = int(sys.argv[3]) if len(sys.argv) > 3 else 0
        if len(sys.argv) > 4:
            pre_train_model_file_name = sys.argv[4]
            model_name += "_pred_%s" % pre_train_model_file_name
        brae_predict(phrase_data_path,
                     train_name + ".pred",
                     model_file="%s.model" % model_name,
                     bilinear=False,
                     num_process=num_process)
    else:
        sys.stderr("train or predict")
        exit(1)
Example #13
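    # Note: this snippet begins mid-function; the reads and writes of
    # module-level settings such as EMBEDDING_LR, MAX_ITER and
    # CROSS_VALIDATION_TIMES below presumably rely on global declarations
    # in the elided portion above.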
    if args.norm_limit > 0:
        model_name += "_nlimit%s" % args.norm_limit
    if args.regular > 0:
        model_name += "_regular%s" % args.regular
    if args.emblr < 0:
        EMBEDDING_LR = -1
    elif args.emblr != EMBEDDING_LR:
        model_name += "_emblr%s" % args.emblr
    else:
        model_name += "_emblr%s" % EMBEDDING_LR
    if args.cross > 1:
        model_name += "_cross%s" % args.cross
        CROSS_VALIDATION_TIMES = args.cross
    if args.prefix is not None:
        model_name = "%s_%s" % (args.prefix, model_name)
    pre_logger(model_name)
    np.random.seed(args.seed)

    if args.epoch != MAX_ITER:
        MAX_ITER = args.epoch
        logger.info("Max Epoch Iter: %s" % MAX_ITER)
    if args.emblr != EMBEDDING_LR:
        EMBEDDING_LR = args.emblr
        logger.info("Embedding Learning Rate: %s" % EMBEDDING_LR)
    if args.test_batch != PRE_TEST_BATCH:
        PRE_TEST_BATCH = args.test_batch
    if args.label is not None:
        LABEL2INDEX = args.label

    if CROSS_VALIDATION_TIMES > 1:
        args.dev = None