Example #1
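The three excerpts below appear to come from a single repository and show none of their imports. A preamble along the following lines is assumed; the project-local names listed in the comment are the repository's own helpers and are only referenced here, not defined.

import os

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
import visdom
# SummaryWriter may come from torch.utils.tensorboard or tensorboardX, depending on the repo
from torch.utils.tensorboard import SummaryWriter

# project-local helpers referenced below (assumed to exist in the repository):
# get_word_embed, get_flag_embed, get_vocab, SyntaxDataset, syntax_fn,
# SyntaxConfig, BiLSTMSyntax, SyntaxLSTM, Visdom_line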
def test_GraphAttention():
    # DATA_PATH = './data/dataset_eval'
    DATA_PATH = './nbc'
    DICT_PATH = './checkpoint/dict_20000.pkl'
    EMBEDDING_PATH_RANDOM = './model/save_embedding_97and3.ckpt'
    GPU_NUM = 0
    model_epoch = 7
    MODEL_PATH = f"./checkpoint/syntax_Attn/model-{model_epoch}.ckpt"
    BATCH_SIZE = 200

    torch.cuda.set_device(GPU_NUM)

    embed = get_word_embed().cuda()
    embed_flag = get_flag_embed().cuda()
    vocab = get_vocab()

    config = SyntaxConfig()
    model = BiLSTMSyntax(config)
    model.load(MODEL_PATH)
    model.cuda()

    trainset = SyntaxDataset(vocab=vocab, data_path=DATA_PATH)
    trainloader = DataLoader(dataset=trainset,
                             batch_size=config.batch_size,
                             collate_fn=syntax_fn,
                             num_workers=10,
                             pin_memory=True,
                             shuffle=False)

    correct_num = 0
    all_num = 0
    recall_correct = 0
    recall_all = 0

    model.eval()
    for index, (src, trg, labels, tags) in enumerate(tqdm(trainloader)):
        print(index)
        batch_size = src.shape[0]
        src = embed(src.cuda())
        trg = embed(trg.cuda())
        labels = labels.cuda()
        tags = tags.cuda()

        flag4encoder = torch.zeros(src.shape[0], src.shape[1], 3).cuda()
        src = torch.cat([src, flag4encoder], dim=2)
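        # the 3-dim zero slot presumably mirrors the flag embedding that gets appended
        # to the decoder input, keeping encoder and decoder feature widths equal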

        encoder_output, hidden, syntax_hidden = model.step_encoding(
            src, tags)  # get the hidden state of the encoder

        output_labels = []
        input_flag = [[2] for _ in range(trg.shape[0])]  # flag value 2 marks the first decoding step
        input_flag = torch.Tensor(input_flag).long().cuda()

        hidden_syntax = syntax_hidden
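        # greedy step-by-step decoding: the label predicted at each step is embedded
        # with embed_flag and fed back as the flag part of the next step's input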
        for step in range(trg.shape[1]):
            flag4decoder = embed_flag(input_flag)
            trg_step = trg[:, step, :].view(batch_size, 1, -1)
            tags_step = tags[:, step].view(batch_size, -1)

            trg_step = torch.cat([trg_step, flag4decoder], dim=-1)

            out, hidden_syntax, hidden = model.test_decoding(
                trg_step,
                tags_step,
                flag4decoder,
                hidden_syntax=hidden_syntax,
                hidden=hidden)

            input_flag = torch.max(out, 2)[1]
            output_labels.append(input_flag)

        output_labels = torch.cat(output_labels, dim=1)
        labels = labels.squeeze()

        mask_matrix = labels < 2
        predict_labels = torch.masked_select(output_labels, mask_matrix)
        ground_truth = torch.masked_select(labels, mask_matrix)
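        # label 2 appears to mark padding (the training loop in Example #2 ignores the
        # same value via ignore_index=2), so only the real 0/1 positions are scored;
        # predict_labels & ground_truth counts true positives (both equal to 1)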

        correct_num += torch.sum(predict_labels == ground_truth).item()
        recall_correct += torch.sum(predict_labels & ground_truth).item()
        recall_all += torch.sum(ground_truth).item()
        all_num += len(ground_truth)

        P = correct_num / all_num
        R = recall_correct / recall_all
        F1 = 2 * P * R / (P + R)
        print('Precision is {}'.format(P))
        print('Recall is {}'.format(R))
        print('F1 is {} \n'.format(F1))

    P = correct_num / all_num
    R = recall_correct / recall_all
    F1 = 2 * P * R / (P + R)

    print('Finally', BATCH_SIZE)
    print('\tPrecision is {}'.format(P))
    print('\tRecall is {}'.format(R))
    print('\tF1 is {}'.format(F1))
    print(correct_num, recall_correct)
    return P, R, F1
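The masked counting block above reappears verbatim in Example #3 below; a small helper along these lines could be shared by both test routines. This is only a sketch, assuming labels are 0/1 with 2 used as padding (consistent with the labels < 2 mask here and ignore_index=2 in Example #2); the helper name is not part of the original code. Note that the quantity reported as "Precision" is correct predictions over all unpadded positions, which is closer to accuracy.

def masked_label_counts(output_labels, labels):
    # counts over the unpadded positions only; label 2 is treated as padding
    mask = labels < 2
    pred = torch.masked_select(output_labels, mask)
    gold = torch.masked_select(labels, mask)
    correct = torch.sum(pred == gold).item()   # predicted label equals the gold label
    true_pos = torch.sum(pred & gold).item()   # predicted 1 and gold 1
    return correct, true_pos, torch.sum(gold).item(), len(gold)

# accumulate the four counts over batches, then
#   P = correct / total, R = true_pos / gold_total, F1 = 2 * P * R / (P + R)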
Example #2
def train(reload_dataset=False, pretrain_model_path=None, optim_fu='adam'):
    writer = SummaryWriter()

    vis = visdom.Visdom(env="syntax_compression")
    viz = Visdom_line(vis=vis, win="syntax_geted_lstm")

    # configuration
    DATA_DIR = '../data/train_pairs'
    DICT_PATH = '../checkpoint/dict_20000.pkl'
    EMBEDDING_PATH_RANDOM = '../model/save_embedding_97and3.ckpt'

    SAVE_EMBEDDING = False
    RELOAD_DATASET = reload_dataset

    SAVE_DATASET_OBJ = '../data/dataset.pkl'
    SAVE_MODEL_PATH = './checkpoint/syntax_gate_lstm/'

    PRINT_STEP = 10
    SAVE_STEP = 1
    GPU_NUM = 1

    torch.manual_seed(2)
    torch.cuda.set_device(GPU_NUM)

    model = SyntaxLSTM(100, 100, 10)
    model.cuda()

    if os.path.exists(SAVE_MODEL_PATH) is False:
        os.makedirs(SAVE_MODEL_PATH)

    # load the word and flag embeddings
    embed = get_word_embed().cuda()
    embed_flag = get_flag_embed().cuda()

    vocab = get_vocab()

    criterion = nn.CrossEntropyLoss(ignore_index=2)
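    # targets equal to 2 are ignored, so padded positions (label 2, cf. the
    # labels < 2 masks in the test routines) do not contribute to the loss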

    # if pretrain_model_path is not None:
    #     print('Loading the pre train model', pretrain_model_path)
    #     model.load(pretrain_model_path)
    #     model.embed.weight.requires_grad = True
    #     parameters = model.parameters()
    #     optimizer = optim.SGD(parameters, lr=0.000001)
    # else:
    parameters = filter(lambda p: p.requires_grad, model.parameters())
    optimizer = optim.Adam(parameters, lr=0.0001)

    trainset = SyntaxDataset(vocab=vocab, reverse_src=True)
    trainloader = DataLoader(dataset=trainset,
                             batch_size=200,
                             collate_fn=syntax_fn,
                             pin_memory=True,
                             num_workers=5,
                             shuffle=True)

    global_step = 0
    loss_print = 0
    step_print = 0
    for epoch in range(500):
        epoch_loss = 0
        for index, (src, trg, labels, tags) in enumerate(tqdm(trainloader)):
            src = embed(src.cuda())
            trg = embed(trg.cuda())
            tags = tags.cuda()

            flag4encoder = torch.zeros(src.shape[0], src.shape[1], 3).cuda()
            src = torch.cat([src, flag4encoder], dim=2)

            flag4decoder = torch.zeros([labels.shape[0], 1]).long()  # all-zero start flag for the decoder
            flag4decoder = torch.cat([flag4decoder, labels[:, :-1]], dim=1).cuda()
            flag4decoder = embed_flag(flag4decoder)
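            # teacher forcing: at position t the decoder's flag input is the embedded
            # gold label of position t - 1 rather than its own previous prediction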

            trg = torch.cat([trg, flag4decoder], dim=2)  # append the 3-dim flag embedding to the decoder input
            labels = labels.cuda()

            out = model(src, trg, tags)
            out = out.view(-1, 2)
            labels = labels.view(-1)
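            # the 2-way output is flattened to (batch * seq_len, 2) and the labels to
            # (batch * seq_len,), so the cross-entropy scores every unpadded position
            # as a binary keep/drop decision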
            loss = criterion(out, labels)
            epoch_loss += loss.item()
            loss_print += loss.item()
            print(loss.item())

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            global_step += 1

            if global_step % PRINT_STEP == 0:
                writer.add_scalar('loss', loss_print / PRINT_STEP, step_print)
                step_print += 1
                loss_print = 0

        model.save(SAVE_MODEL_PATH + 'model-' + str(epoch) + '.ckpt')
        writer.add_scalar('epoch_loss', epoch_loss, epoch)
Example #3
def test_syntax_lstm():
    DATA_PATH = '../data/dataset_eval'
    DICT_PATH = '../checkpoint/dict_20000.pkl'
    EMBEDDING_PATH_RANDOM = '../model/save_embedding_97and3.ckpt'
    GPU_NUM = 0
    model_epoch = 20
    MODEL_PATH = f"./checkpoint/syntax_gate_lstm/model-{model_epoch}.ckpt"

    torch.cuda.set_device(GPU_NUM)

    embed = get_word_embed().cuda()
    embed_flag = get_flag_embed().cuda()
    vocab = get_vocab()

    model = SyntaxLSTM(100, 100, 10)
    model.load(MODEL_PATH)  # load the checkpoint saved by train()
    model.cuda()

    testset = SyntaxDataset(vocab=vocab, data_path=DATA_PATH)
    testloader = DataLoader(dataset=testset,
                            batch_size=200,
                            collate_fn=syntax_fn,
                            num_workers=2,
                            pin_memory=True,
                            shuffle=False)

    correct_num = 0
    all_num = 0
    recall_correct = 0
    recall_all = 0

    model.eval()
    for index, (src, trg, labels, tags) in enumerate(tqdm(testloader)):
        print(index)
        batch_size = src.shape[0]
        src = embed(src.cuda())
        trg = embed(trg.cuda())
        labels = labels.cuda()
        tags = tags.cuda()

        # pad the encoder input with a 3-dim all-zero flag slot, as in Example #1
        flag4encoder = torch.zeros(src.shape[0], src.shape[1], 3).cuda()
        src = torch.cat([src, flag4encoder], dim=2)

        output_labels = model.testing(src, trg, tags, embed_flag)
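        # model.testing is assumed to wrap the same step-wise greedy decoding shown in
        # Example #1, returning one 0/1 label per target position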

        # print(output_labels)

        labels = labels.squeeze()
        mask_matrix = labels < 2
        predict_labels = torch.masked_select(output_labels, mask_matrix)
        ground_truth = torch.masked_select(labels, mask_matrix)

        correct_num += torch.sum(predict_labels == ground_truth).item()
        recall_correct += torch.sum(predict_labels & ground_truth).item()
        recall_all += torch.sum(ground_truth).item()
        all_num += len(ground_truth)

        P = correct_num / all_num
        R = recall_correct / recall_all
        F1 = 2 * P * R / (P + R)
        print('Precision is {}'.format(P))
        print('Recall is {}'.format(R))
        print('F1 is {} \n'.format(F1))

    P = correct_num / all_num
    R = recall_correct / recall_all
    F1 = 2 * P * R / (P + R)

    print('Finally')
    print('\tPrecision is {}'.format(P))
    print('\tRecall is {}'.format(R))
    print('\tF1 is {}'.format(F1))
    print(correct_num, recall_correct)
    return P, R, F1
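None of the excerpts show an entry point; a minimal driver, assuming all three functions live in one module, might look like this (the call and its argument are illustrative):

if __name__ == '__main__':
    # illustrative only: paths, epochs and GPU ids are hard-coded inside the functions above
    train(reload_dataset=False)
    test_syntax_lstm()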