# save model in this time
    shutil.copy("./models/model_BiLSTM_1.py", "./snapshot/" + mulu)

# Move the model to the GPU when CUDA was requested, then print the model.
if args.cuda is True:
    print("using cuda......")
    model = model.cuda()
print(model)

# train
print("\n cpu_count \n", mu.cpu_count())
torch.set_num_threads(args.num_threads)
# Delete any result file left on disk so that the lines read back below were
# not produced by an earlier run (presumably train() re-creates the file;
# verify against the train modules).
if os.path.exists("./Test_Result.txt"):
    os.remove("./Test_Result.txt")
# Exactly one training routine runs: args.CNN takes precedence over
# args.BiLSTM_1.
# NOTE(review): when neither flag is True, model_count is never bound and the
# print below raises NameError.
if args.CNN is True:
    print("CNN training start......")
    model_count = train_ALL_CNN.train(train_iter, dev_iter, test_iter, model,
                                      args)
elif args.BiLSTM_1 is True:
    print("BiLSTM_1 training start......")
    model_count = train_ALL_LSTM.train(train_iter, dev_iter, test_iter, model,
                                       args)
print("Model_count", model_count)

# Collect one float per line that starts with "Evaluation" in the result file.
resultlist = []
if os.path.exists("./Test_Result.txt"):
    file = open("./Test_Result.txt")
    for line in file.readlines():
        if line[:10] == "Evaluation":
            # The score is taken from a fixed column range of the line;
            # assumes the writer keeps that exact layout - TODO confirm.
            resultlist.append(float(line[34:41]))
    # Ascending order, so the largest collected value is the last element.
    result = sorted(resultlist)
    file.close()
    # Reopen in append mode; the statements that use this handle are outside
    # the visible part of the snippet.
    file = open("./Test_Result.txt", "a")
# Example #2
# 0
def main():
    """
    :function: prepare the run directory, load the selected dataset, build
               the vocabulary-derived settings on ``args``, optionally load
               pretrained embeddings, train the CNN model and summarise the
               results
    :return: None
    """
    # (disabled in the original: forcing deterministic cuDNN algorithms when
    # args.use_cuda is True)

    # comma-separated string -> list of ints, e.g. "3,4,5" -> [3, 4, 5]
    args.kernel_sizes = list(map(int, args.kernel_sizes.split(',')))

    # every run saves into its own time-stamped sub-directory of save_dir
    run_stamp = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    args.mulu = run_stamp
    args.save_dir = os.path.join(args.save_dir, run_stamp)
    if not os.path.isdir(args.save_dir):
        os.makedirs(args.save_dir)

    def iterator_options(with_dev):
        # Keyword arguments shared by the dataset loaders.  Built lazily at
        # the call site so ``args`` is read only for a dataset that is
        # actually selected; key order mirrors the loaders' keyword order.
        options = {"train_path": args.train_path}
        if with_dev:
            options["dev_path"] = args.dev_path
        options["test_path"] = args.test_path
        options["device"] = args.gpu_device
        options["repeat"] = False
        options["shuffle"] = args.epochs_shuffle
        options["sort"] = False
        return options

    # build vocab and iterator
    text_field = data.Field(lower=True)
    label_field = data.Field(sequential=False)
    # The dataset flags are tested independently (not elif): if several are
    # True, each loader runs and the later one overwrites the iterators.
    if args.SST_1 is True:
        print("loading sst-1 dataset......")
        train_iter, dev_iter, test_iter = load_SST_1(
            text_field, label_field, **iterator_options(with_dev=True))
    if args.SST_2 is True:
        print("loading sst-2 dataset......")
        train_iter, dev_iter, test_iter = load_SST_2(
            text_field, label_field, **iterator_options(with_dev=True))
    if args.TREC is True:
        # TREC yields no dev iterator, only train and test
        print("loading TREC dataset......")
        train_iter, test_iter = load_TREC(
            text_field, label_field, **iterator_options(with_dev=False))

    # sizes derived from the vocabularies
    vocab = text_field.vocab
    args.embed_num = len(vocab)
    args.class_num = len(label_field.vocab) - 1
    args.PaddingID = vocab.stoi[text_field.pad_token]
    summary = "embed_num : {}, class_num : {}".format(args.embed_num,
                                                      args.class_num)
    print(summary)
    print("PaddingID {}".format(args.PaddingID))

    # pretrained word embedding
    if args.word_Embedding:
        pretrain_embed = load_pretrained_emb_zeros(
            path=args.word_Embedding_Path,
            text_field_words_dict=text_field.vocab.itos,
            pad=text_field.pad_token)
        calculate_oov(
            path=args.word_Embedding_Path,
            text_field_words_dict=text_field.vocab.itos,
            pad=text_field.pad_token)
        args.pretrained_weight = pretrain_embed

    # print params
    show_params()

    # load model and start train
    if args.CNN is True:
        print("loading CNN model.....")
        # NOTE(review): the model instantiated is SumPooling, while the file
        # copied into the save directory is model_CNN.py - confirm intended.
        model = model_SumPooling.SumPooling(args)
        shutil.copy("./models/model_CNN.py", args.save_dir)
        print(model)
        if args.use_cuda is True:
            print("using cuda......")
            model = model.cuda()
        print("CNN training start......")
        # remove a result file left on disk before training starts
        if os.path.exists("./Test_Result.txt"):
            os.remove("./Test_Result.txt")
        has_dev_split = args.SST_1 is True or args.SST_2 is True
        if has_dev_split:
            model_count = train_ALL_CNN.train(
                train_iter, dev_iter, test_iter, model, args)
        if args.TREC is True:
            model_count = train_CV.train(train_iter, test_iter, model, args)

    # calculate the best result
    cal_result()
# Example #3
# 0
# Model-selection flags read from ``config``, in priority order: the first
# enabled flag decides which model family is trained.  The boolean marks the
# flags that are compared with ``is True`` instead of plain truthiness.
_TRAIN_FLAGS = (
    ("CNN", False),
    ("DEEP_CNN", False),
    ("LSTM", False),
    ("GRU", False),
    ("BiLSTM", False),
    ("BiLSTM_1", False),
    ("CNN_LSTM", False),
    ("CLSTM", False),
    ("CBiLSTM", False),
    ("CGRU", False),
    ("CNN_BiLSTM", False),
    ("BiGRU", False),
    ("CNN_BiGRU", False),
    ("CNN_MUI", False),
    ("DEEP_CNN_MUI", False),
    ("HighWay_CNN", True),
    ("HighWay_BiLSTM_1", True),
)

# Flags trained through ``train_ALL_CNN``; every other flag goes through
# ``train_ALL_LSTM``.
_CNN_TRAINED_FLAGS = frozenset(
    ("CNN", "DEEP_CNN", "CNN_MUI", "DEEP_CNN_MUI", "HighWay_CNN"))

_RESULT_FILE = "./Test_Result.txt"


def _selected_model_flag():
    """Return the name of the first enabled model flag on ``config``, or None."""
    for name, strict in _TRAIN_FLAGS:
        # Attributes are read one at a time and only until a match is found.
        value = getattr(config, name)
        if (value is True) if strict else value:
            return name
    return None


def _record_best_result():
    """Append the best score to the result file and copy it to the snapshot dir.

    Does nothing when the result file does not exist.  When the file holds no
    "Evaluation" lines, the summary line is skipped but the file is still
    copied.
    """
    if not os.path.exists(_RESULT_FILE):
        return
    with open(_RESULT_FILE) as handle:
        # The score sits in a fixed column range of each "Evaluation" line.
        scores = [float(line[34:41]) for line in handle
                  if line.startswith("Evaluation")]
    if scores:
        with open(_RESULT_FILE, "a") as handle:
            handle.write("\nThe Best Result is : " + str(max(scores)))
            handle.write("\n")
    shutil.copy(_RESULT_FILE,
                "./snapshot/" + config.mulu + "/Test_Result.txt")


def start_train(model, train_iter, dev_iter, test_iter):
    """
    :function: run one of three modes chosen by ``config``: predict the label
               of ``config.predict``, evaluate on the test iterator
               (``config.test``), or train the model selected by the model
               flags and record the best score
    :param model: model to predict with, evaluate or train
    :param train_iter: training data iterator
    :param dev_iter: development data iterator
    :param test_iter: test data iterator
    :return: None
    :raises ValueError: in training mode, when no model flag is enabled
    """
    if config.predict is not None:
        label = train_ALL_CNN.predict(config.predict, model, config.text_field,
                                      config.label_field)
        print('\n[Text]  {}[Label] {}\n'.format(config.predict, label))
    elif config.test:
        try:
            print(test_iter)
            train_ALL_CNN.test_eval(test_iter, model, config)
        except Exception as e:
            # Still best-effort, but show the real error: the failure is not
            # necessarily a missing dataset.
            print("\nSorry. The test dataset doesn't  exist.\n")
            print("Reason: {!r}".format(e))
    else:
        # Resolve the model before touching anything, so a misconfigured run
        # fails with a clear message and does not delete existing results.
        flag = _selected_model_flag()
        if flag is None:
            raise ValueError(
                "no model flag is enabled on config; expected one of: " +
                ", ".join(name for name, _ in _TRAIN_FLAGS))

        print("\n cpu_count \n", mu.cpu_count())
        torch.set_num_threads(config.num_threads)
        # Remove a result file left on disk before training starts.
        if os.path.exists(_RESULT_FILE):
            os.remove(_RESULT_FILE)

        print(flag + " training start......")
        if flag in _CNN_TRAINED_FLAGS:
            trainer = train_ALL_CNN
        else:
            trainer = train_ALL_LSTM
        model_count = trainer.train(train_iter, dev_iter, test_iter, model,
                                    config)
        print("Model_count", model_count)
        _record_best_result()
    label = train_ALL_CNN.predict(args.predict, model, text_field, label_field)
    print('\n[Text]  {}[Label] {}\n'.format(args.predict, label))
elif args.test:
    try:
        print(test_iter)
        train_ALL_CNN.test_eval(test_iter, model, args)
    except Exception as e:
        print("\nSorry. The test dataset doesn't  exist.\n")
else:
    print("\n cpu_count \n", mu.cpu_count())
    torch.set_num_threads(args.num_threads)
    if os.path.exists("./Test_Result.txt"):
        os.remove("./Test_Result.txt")
    if args.CNN:
        print("CNN training start......")
        model_count = train_ALL_CNN.train(train_iter, dev_iter, test_iter, model, args)
    elif args.DEEP_CNN:
        print("DEEP_CNN training start......")
        model_count = train_ALL_CNN.train(train_iter, dev_iter, test_iter, model, args)
    elif args.LSTM:
        print("LSTM training start......")
        model_count = train_ALL_LSTM.train(train_iter, dev_iter, test_iter, model, args)
    elif args.GRU:
        print("GRU training start......")
        model_count = train_ALL_LSTM.train(train_iter, dev_iter, test_iter, model, args)
    elif args.BiLSTM:
        print("BiLSTM training start......")
        model_count = train_ALL_LSTM.train(train_iter, dev_iter, test_iter, model, args)
    elif args.BiLSTM_1:
        print("BiLSTM_1 training start......")
        model_count = train_ALL_LSTM.train(train_iter, dev_iter, test_iter, model, args)