Example #1
0
def _macro_scores(true_lst, pred_lst):
    """Return ``(accuracy, precision, recall, f1)`` for one pass over a data set.

    Precision, recall and F1 are macro-averaged. When no batch produced
    predictions (empty lists) all four values are NaN, instead of passing empty
    inputs to the metric functions.
    """
    if not true_lst:
        nan = float("nan")
        return nan, nan, nan, nan
    return (accuracy_score(true_lst, pred_lst),
            precision_score(true_lst, pred_lst, average='macro'),
            recall_score(true_lst, pred_lst, average='macro'),
            f1_score(true_lst, pred_lst, average='macro'))


def train(config):
    """Train an anchor/check classifier and checkpoint it on dev-set metrics.

    Loads the hyper-parameters (``FLAGS``) with
    ``namespace_utils.load_namespace``, prepares the train and dev data, builds
    the model selected by ``FLAGS.scope`` and trains for ``FLAGS.max_epochs``
    epochs. After every epoch the dev set is scored. A checkpoint is written to
    ``<model_dir>/<model_name>/models`` whenever the dev macro-F1 or the dev
    loss reaches a new best. After ``toleration`` consecutive epochs without a
    new best, the learning rate is halved.

    Args:
        config: dict-like object with the keys ``model_config_path``,
            ``train_path``, ``w2v_path``, ``vocab_path``, ``dev_path`` and
            ``model_dir``. Optional keys are ``gpu_id``, ``elmo_w2v_path``,
            ``label_emb_path`` (pickled label-embedding matrix), ``emb_idf``
            and ``model``. ``model`` is only used when the model config has no
            ``output_folder_name``.

    Raises:
        ValueError: if ``FLAGS.scope`` names an unsupported model.
    """
    import math
    import pickle

    model_config_path = config["model_config_path"]
    FLAGS = namespace_utils.load_namespace(model_config_path)

    os.environ["CUDA_VISIBLE_DEVICES"] = config.get("gpu_id", "")
    train_path = config["train_path"]
    w2v_path = config["w2v_path"]
    vocab_path = config["vocab_path"]
    dev_path = config["dev_path"]
    elmo_w2v_path = config.get("elmo_w2v_path", None)
    label_emb_path = config.get("label_emb_path", None)
    # ``config`` is read with ``[]``/``.get`` everywhere else. The previous
    # ``config.emb_idf`` attribute access fails on a plain dict.
    emb_idf = config.get("emb_idf", False)

    # Always bind label_emb_mat. It is assigned to FLAGS.class_emb_mat below and
    # used to raise NameError when no label embedding path was configured.
    label_emb_mat = None
    if label_emb_path:
        # NOTE(review): pickle.load can execute arbitrary code; the file must
        # come from a trusted source.
        with open(label_emb_path, "rb") as fin:
            label_emb_mat = pickle.load(fin)

    model_dir = config["model_dir"]
    try:
        model_name = FLAGS["output_folder_name"]
    except (KeyError, AttributeError):
        model_name = config["model"]
    print(model_name, "====model name====")

    run_dir = os.path.join(model_dir, model_name)
    log_dir = os.path.join(run_dir, "logs")
    ckpt_dir = os.path.join(run_dir, "models")
    for path in (log_dir, ckpt_dir):
        if not os.path.isdir(path):
            os.makedirs(path)

    # Dump the loaded hyper-parameters before the embedding matrices are
    # attached to FLAGS further down.
    with open(os.path.join(log_dir, model_name + ".json"), "w") as fout:
        json.dump(FLAGS, fout)
    logger = logger_utils.get_logger(os.path.join(log_dir, "log.info"))

    (train_anchor, train_check, train_label, _train_anchor_len,
     _train_check_len, embedding_info) = prepare_data(
         train_path,
         w2v_path,
         vocab_path,
         make_vocab=True,
         elmo_w2v_path=elmo_w2v_path,
         elmo_pca=FLAGS.elmo_pca,
         emb_idf=emb_idf)

    # The embedding_info returned for the dev set (make_vocab=False) is the one
    # used from here on.
    (dev_anchor, dev_check, dev_label, _dev_anchor_len, _dev_check_len,
     embedding_info) = prepare_data(
         dev_path,
         w2v_path,
         vocab_path,
         make_vocab=False,
         elmo_w2v_path=elmo_w2v_path,
         elmo_pca=FLAGS.elmo_pca,
         emb_idf=emb_idf)

    token2id = embedding_info["token2id"]
    embedding_mat = embedding_info["embedding_matrix"]
    extra_symbol = embedding_info["extra_symbol"]

    if emb_idf:
        FLAGS.idf_emb_mat = embedding_info["idf_matrix"]
        FLAGS.with_idf = True

    FLAGS.token_emb_mat = embedding_mat
    FLAGS.char_emb_mat = 0
    FLAGS.vocab_size = embedding_mat.shape[0]
    FLAGS.char_vocab_size = 0
    FLAGS.emb_size = embedding_mat.shape[1]
    FLAGS.extra_symbol = extra_symbol
    FLAGS.class_emb_mat = label_emb_mat

    scope = FLAGS.scope
    if scope == "ESIM":
        model = ESIM()
    elif scope == "BiBLOSA":
        model = BiBLOSA()
    elif scope == "BaseTransformer":
        model = BaseTransformer()
    elif scope == "UniversalTransformer":
        model = UniversalTransformer()
    elif scope == "Capsule":
        model = Capsule()
    elif scope == "LabelNetwork":
        model = LabelNetwork()
    elif scope == "LEAM":
        model = LEAM()
    elif scope == "SWEM":
        model = SWEM()
    elif scope == "TextCNN":
        model = TextCNN()
    elif scope == "DeepPyramid":
        model = DeepPyramid()
    elif scope == "ReAugument":
        model = ReAugument()
    else:
        # Previously fell through and failed later with an unbound ``model``.
        raise ValueError("unsupported model scope: {!r}".format(scope))

    # Every model uses the same length limits. Capsule and DeepPyramid
    # additionally ask get_batches to enforce them (if_max_*_len=True),
    # presumably because they need fixed-size inputs. The old code assigned
    # ``True,`` here, which is a one-element tuple rather than a bool.
    max_anchor_len = FLAGS.max_length
    max_check_len = 1
    if_max_anchor_len = scope in ("Capsule", "DeepPyramid")
    if_max_check_len = if_max_anchor_len

    model.build_placeholder(FLAGS)
    model.build_op()
    model.init_step()

    best_dev_f1 = 0.0
    # Start from +inf so the first finite dev loss always counts as a new best.
    # The old sentinel of 100.0 ignored losses above 100.
    best_dev_loss = float("inf")
    learning_rate = FLAGS.learning_rate
    toleration = 1000  # epochs without a new best before the LR is halved
    toleration_cnt = 0
    print("=======begin to train=========")
    for epoch in range(FLAGS.max_epochs):
        train_data = get_batch_data.get_batches(
            train_anchor,
            train_check,
            train_label,
            FLAGS.batch_size,
            token2id,
            is_training=True,
            max_anchor_len=max_anchor_len,
            if_max_anchor_len=if_max_anchor_len,
            max_check_len=max_check_len,
            if_max_check_len=if_max_check_len)

        train_loss, cnt = 0.0, 0
        train_pred_lst, train_true_lst = [], []
        for index, (anchor, entity, label) in enumerate(train_data):
            assert entity.shape[-1] == 1
            try:
                loss, _, _, _, preds = model.step([anchor, entity, label],
                                                  is_training=True,
                                                  learning_rate=learning_rate)
                if math.isnan(loss):
                    # Abandon the rest of this epoch's pass at a diverged batch.
                    print(anchor, entity, label, loss, "===nan loss===")
                    break
                train_pred_lst += np.argmax(preds, axis=-1).tolist()
                train_true_lst += label.tolist()
                train_loss += loss * anchor.shape[0]
                cnt += anchor.shape[0]
            except Exception as err:
                # Best-effort: skip a failing batch, but leave a trace in the log.
                logger.info("epoch\t{}\ttrain\tskipped batch\t{}\t{!r}".format(
                    epoch, index, err))

        # Sample-weighted mean of the batch losses; NaN if no batch succeeded.
        train_loss = train_loss / float(cnt) if cnt else float("nan")
        (train_accuracy, train_precision, train_recall,
         train_f1) = _macro_scores(train_true_lst, train_pred_lst)

        info = OrderedDict()
        info["epoch"] = str(epoch)
        info["train_loss"] = str(train_loss)
        info["train_accuracy"] = str(train_accuracy)
        info["train_f1"] = str(train_f1)
        info["train_precision"] = str(train_precision)
        info["train_recall"] = str(train_recall)

        logger.info("epoch\t{}\ttrain\tloss\t{}\taccuracy\t{}\tf1\t{}".format(
            epoch, train_loss, train_accuracy, train_f1))

        dev_data = get_batch_data.get_batches(
            dev_anchor,
            dev_check,
            dev_label,
            FLAGS.batch_size,
            token2id,
            is_training=False,
            max_anchor_len=max_anchor_len,
            if_max_anchor_len=if_max_anchor_len,
            max_check_len=max_check_len,
            if_max_check_len=if_max_check_len)

        dev_loss, cnt = 0.0, 0
        dev_pred_lst, dev_true_lst = [], []
        for index, (anchor, entity, label) in enumerate(dev_data):
            try:
                loss, _, pred_probs, _ = model.infer([anchor, entity, label],
                                                     mode="test",
                                                     is_training=False,
                                                     learning_rate=learning_rate)
                dev_pred_lst += np.argmax(pred_probs, axis=-1).tolist()
                dev_true_lst += label.tolist()
                if math.isnan(loss):
                    print(anchor, entity, pred_probs, index)
                dev_loss += loss * anchor.shape[0]
                cnt += anchor.shape[0]
            except Exception as err:
                logger.info("epoch\t{}\tdev\tskipped batch\t{}\t{!r}".format(
                    epoch, index, err))

        dev_loss = dev_loss / float(cnt) if cnt else float("nan")
        (dev_accuracy, dev_precision, dev_recall,
         dev_f1) = _macro_scores(dev_true_lst, dev_pred_lst)

        info["dev_loss"] = str(dev_loss)
        info["dev_accuracy"] = str(dev_accuracy)
        info["dev_f1"] = str(dev_f1)
        info["dev_precision"] = str(dev_precision)
        info["dev_recall"] = str(dev_recall)

        logger.info("epoch\t{}\tdev\tloss\t{}\taccuracy\t{}\tf1\t{}".format(
            epoch, dev_loss, dev_accuracy, dev_f1))

        if dev_f1 > best_dev_f1 or dev_loss < best_dev_loss:
            timestamp = str(int(time.time()))
            model.save_model(
                ckpt_dir,
                model_name + "_{}_{}_{}".format(timestamp, dev_loss, dev_f1))
            # Track each metric's own best. Previously both were overwritten
            # with this epoch's values, so an F1 gain could raise the loss bar
            # (and vice versa).
            best_dev_f1 = max(best_dev_f1, dev_f1)
            best_dev_loss = min(best_dev_loss, dev_loss)

            toleration_cnt = 0

            info["best_dev_loss"] = str(best_dev_loss)
            info["best_dev_f1"] = str(best_dev_f1)

            logger_utils.json_info(os.path.join(log_dir, "info.json"), info)
            logger.info(
                "epoch\t{}\tbest_dev\tloss\t{}\tbest_accuracy\t{}\tbest_f1\t{}"
                .format(epoch, dev_loss, dev_accuracy, best_dev_f1))
        else:
            toleration_cnt += 1
            if toleration_cnt == toleration:
                # Plateau: halve the learning rate and restart the count.
                toleration_cnt = 0
                learning_rate *= 0.5
Example #2
0
def train(config):
    """Train a single-corpus text classifier and checkpoint on training metrics.

    Loads the hyper-parameters (``FLAGS``) with
    ``namespace_utils.load_namespace``, prepares the training corpus, builds
    the model selected by ``FLAGS.scope`` and trains for up to
    ``FLAGS.max_epochs`` epochs. A checkpoint is written to
    ``<model_dir>/<model>/models`` only when both the epoch's training accuracy
    and its training loss beat their best values so far. Training stops early
    after ``toleration`` consecutive epochs without such an improvement.

    Args:
        config: dict-like object with the keys ``model_config_path``,
            ``train_path``, ``w2v_path``, ``vocab_path``, ``dev_path``,
            ``model_dir`` and ``model``. ``gpu_id`` is optional.

    Raises:
        ValueError: if ``FLAGS.scope`` names an unsupported model.
    """
    model_config_path = config["model_config_path"]
    FLAGS = namespace_utils.load_namespace(model_config_path)

    os.environ["CUDA_VISIBLE_DEVICES"] = config.get("gpu_id", "")
    train_path = config["train_path"]
    w2v_path = config["w2v_path"]
    vocab_path = config["vocab_path"]
    # Still a required key, although this function never uses a dev set.
    dev_path = config["dev_path"]

    model_dir = config["model_dir"]
    model_name = config["model"]

    run_dir = os.path.join(model_dir, model_name)
    log_dir = os.path.join(run_dir, "logs")
    ckpt_dir = os.path.join(run_dir, "models")
    for path in (log_dir, ckpt_dir):
        if not os.path.isdir(path):
            os.makedirs(path)

    # Dump the loaded hyper-parameters before the embedding matrices are
    # attached to FLAGS further down.
    with open(os.path.join(log_dir, model_name + ".json"), "w") as fout:
        json.dump(FLAGS, fout)
    logger = logger_utils.get_logger(os.path.join(log_dir, "log.info"))

    (train_corpus, train_corpus_label, _train_corpus_len,
     embedding_info) = prepare_data(train_path,
                                    w2v_path,
                                    vocab_path,
                                    make_vocab=True)

    token2id = embedding_info["token2id"]
    embedding_mat = embedding_info["embedding_matrix"]

    FLAGS.token_emb_mat = embedding_mat
    FLAGS.char_emb_mat = 0
    FLAGS.vocab_size = embedding_mat.shape[0]
    FLAGS.char_vocab_size = 0
    FLAGS.emb_size = embedding_mat.shape[1]
    FLAGS.extra_symbol = embedding_info["extra_symbol"]

    scope = FLAGS.scope
    if scope == "ESIM":
        model = ESIM()
    elif scope == "BiBLOSA":
        model = BiBLOSA()
    elif scope == "BaseTransformer":
        model = BaseTransformer()
    elif scope == "UniversalTransformer":
        model = UniversalTransformer()
    else:
        # Previously fell through and failed later with an unbound ``model``.
        raise ValueError("unsupported model scope: {!r}".format(scope))

    model.build_placeholder(FLAGS)
    model.build_op()
    model.init_step()

    # Start the loss tracker at +inf so the first finite loss can count as a
    # best. The old sentinel of 100 blocked all saves for losses above 100.
    best_train_accuracy, best_train_loss = 0.0, float("inf")
    toleration = 1000  # epochs without improvement before stopping
    toleration_cnt = 0
    for epoch in range(FLAGS.max_epochs):
        train_data = get_batch_data.get_classify_batch(
            train_corpus,
            train_corpus_label,
            FLAGS.batch_size,
            token2id,
            is_training=True,
            if_word_drop=FLAGS.with_word_drop,
            word_drop_rate=FLAGS.word_drop_rate)

        train_loss, train_accuracy, cnt = 0.0, 0.0, 0
        for index, (anchor, label) in enumerate(train_data):
            try:
                loss, _, _, accuracy, _ = model.step([anchor, label],
                                                     is_training=True)
                batch_size = anchor.shape[0]
                train_loss += loss * batch_size
                train_accuracy += accuracy * batch_size
                cnt += batch_size
            except Exception as err:
                # Best-effort: skip a failing batch, but leave a trace in the log.
                logger.info("epoch\t{}\ttrain\tskipped batch\t{}\t{!r}".format(
                    epoch, index, err))

        # Sample-weighted means over the epoch; NaN if no batch succeeded
        # (previously a ZeroDivisionError). NaN never counts as an improvement.
        if cnt:
            train_loss /= float(cnt)
            train_accuracy /= float(cnt)
        else:
            train_loss = train_accuracy = float("nan")

        info = OrderedDict()
        info["epoch"] = str(epoch)
        info["train_loss"] = str(train_loss)
        info["train_accuracy"] = str(train_accuracy)

        logger.info("epoch\t{}\ttrain\tloss\t{}\taccuracy\t{}".format(
            epoch, train_loss, train_accuracy))

        if train_accuracy > best_train_accuracy and train_loss < best_train_loss:
            timestamp = str(int(time.time()))
            model.save_model(
                ckpt_dir, model_name +
                "_{}_{}_{}".format(timestamp, train_loss, train_accuracy))
            best_train_accuracy = train_accuracy
            best_train_loss = train_loss

            toleration_cnt = 0

            info["best_train_loss"] = str(best_train_loss)
            info["best_train_accuracy"] = str(best_train_accuracy)

            logger_utils.json_info(os.path.join(log_dir, "info.json"), info)
            logger.info(
                "epoch\t{}\tbest_train\tloss\t{}\tbest_accuracy\t{}".format(
                    epoch, train_loss, train_accuracy))
        else:
            toleration_cnt += 1

        if toleration_cnt >= toleration:
            break