Ejemplo n.º 1
0
def main(args):
    """Train an RCNN text classifier and optionally evaluate it on a test set.

    The model is built from ``args`` and moved to ``args.device`` (wrapped in
    ``DataParallel`` when ``args.n_gpu > 1``). The vocabulary is built from
    the training texts, ``args.num_val_data`` samples are held out for
    validation, and the model is trained with Adam. When ``args.test_set`` is
    truthy, ``best.pt`` is loaded from ``args.model_save_path`` and the test
    metrics and confusion matrix are logged.

    Args:
        args: Namespace-like object. Attributes read here: ``vocab_size``,
            ``embedding_dim``, ``hidden_size``, ``hidden_size_linear``,
            ``class_num``, ``dropout``, ``device``, ``n_gpu``,
            ``train_file_path``, ``num_val_data``, ``batch_size``, ``lr``,
            ``test_set``, ``test_file_path`` and ``model_save_path``. It is
            also forwarded to ``collate_fn``, ``train`` and ``evaluate``.
    """
    model = RCNN(vocab_size=args.vocab_size,
                 embedding_dim=args.embedding_dim,
                 hidden_size=args.hidden_size,
                 hidden_size_linear=args.hidden_size_linear,
                 class_num=args.class_num,
                 dropout=args.dropout).to(args.device)

    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model, dim=0)

    def collate(batch):
        # Single collate wrapper shared by all three DataLoaders.
        return collate_fn(batch, args)

    train_texts, train_labels = read_file(args.train_file_path)
    word2idx = build_dictionary(train_texts, vocab_size=args.vocab_size)
    logger.info('Dictionary Finished!')

    full_dataset = CustomTextDataset(train_texts, train_labels, word2idx)
    num_train_data = len(full_dataset) - args.num_val_data
    train_dataset, val_dataset = random_split(
        full_dataset, [num_train_data, args.num_val_data])
    train_dataloader = DataLoader(dataset=train_dataset,
                                  collate_fn=collate,
                                  batch_size=args.batch_size,
                                  shuffle=True)

    # Evaluation data does not need shuffling; a fixed order keeps the
    # batching (and therefore the reported metrics) reproducible.
    valid_dataloader = DataLoader(dataset=val_dataset,
                                  collate_fn=collate,
                                  batch_size=args.batch_size,
                                  shuffle=False)

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    train(model, optimizer, train_dataloader, valid_dataloader, args)
    logger.info('******************** Train Finished ********************')

    # Test
    if args.test_set:
        test_texts, test_labels = read_file(args.test_file_path)
        test_dataset = CustomTextDataset(test_texts, test_labels, word2idx)
        test_dataloader = DataLoader(dataset=test_dataset,
                                     collate_fn=collate,
                                     batch_size=args.batch_size,
                                     shuffle=False)

        # NOTE(review): "best.pt" is presumably written by train(); confirm.
        # map_location keeps the load working when the checkpoint was saved
        # on a device that differs from args.device.
        model.load_state_dict(
            torch.load(os.path.join(args.model_save_path, "best.pt"),
                       map_location=args.device))
        _, accuracy, precision, recall, f1, cm = evaluate(
            model, test_dataloader, args)
        logger.info('-' * 50)
        logger.info(
            f'|* TEST SET *| |ACC| {accuracy:>.4f} |PRECISION| {precision:>.4f} |RECALL| {recall:>.4f} |F1| {f1:>.4f}'
        )
        logger.info('-' * 50)
        logger.info('---------------- CONFUSION MATRIX ----------------')
        for row in cm:
            logger.info(row)
        logger.info('--------------------------------------------------')
Ejemplo n.º 2
0
def load_model(test_arguments):
    """Build an RCNN on the GPU, restore its saved weights and return it.

    Args:
        test_arguments: Object providing ``pos_loss_method`` and
            ``loss_weight_lambda`` (passed positionally to ``RCNN``) and
            ``model_path`` (the checkpoint handed to ``t.load``).

    Returns:
        The RCNN instance after ``eval()`` has been called on it.
    """
    network = RCNN(
        test_arguments.pos_loss_method,
        test_arguments.loss_weight_lambda,
    )
    network = network.cuda()

    saved_weights = t.load(test_arguments.model_path)
    network.load_state_dict(saved_weights)

    # Evaluation mode: per the original author's note, dropout rate = 0.
    network.eval()
    return network
Ejemplo n.º 3
0
def main(args):
    """Train and test the RCNN classifier over ten runs and plot the metrics.

    Each run calls ``setup_data()``, builds a fresh model, rebuilds the
    dictionary and embedding from the training texts, trains with Adam and,
    when ``args.test_set`` is truthy, evaluates ``best.pt`` from
    ``args.model_save_path`` on the test file. Per-run accuracy, precision,
    recall and F1 are collected, their averages are logged, and a line plot
    of the per-run values is saved under ``./images``. The file name is
    derived from the ``lexical`` / ``semantic`` / ``syntactic`` flags.

    If no test results were collected (``args.test_set`` falsy), the
    averages and the plot are skipped instead of dividing by zero.

    Args:
        args: Namespace-like object. Attributes read here: ``vocab_size``,
            ``embedding_dim``, ``hidden_size``, ``hidden_size_linear``,
            ``class_num``, ``dropout``, ``device``, ``n_gpu``,
            ``train_file_path``, ``lexical``, ``syntactic``, ``semantic``,
            ``num_val_data``, ``batch_size``, ``lr``, ``test_set``,
            ``test_file_path`` and ``model_save_path``. It is also forwarded
            to ``CustomTextDataset``, ``collate_fn``, ``train`` and
            ``evaluate``.
    """
    num_runs = 10  # number of independent train/test repetitions
    acc_list = []
    f1_score_list = []
    prec_list = []
    recall_list = []

    def collate(batch):
        # Single collate wrapper shared by every DataLoader below.
        return collate_fn(batch, args)

    for _ in range(num_runs):
        setup_data()
        model = RCNN(vocab_size=args.vocab_size,
                     embedding_dim=args.embedding_dim,
                     hidden_size=args.hidden_size,
                     hidden_size_linear=args.hidden_size_linear,
                     class_num=args.class_num,
                     dropout=args.dropout).to(args.device)

        if args.n_gpu > 1:
            model = torch.nn.DataParallel(model, dim=0)

        train_texts, train_labels = read_file(args.train_file_path)
        word2idx, embedding = build_dictionary(train_texts, args.vocab_size,
                                               args.lexical, args.syntactic,
                                               args.semantic)

        logger.info('Dictionary Finished!')

        full_dataset = CustomTextDataset(train_texts, train_labels, word2idx,
                                         args)
        num_train_data = len(full_dataset) - args.num_val_data
        train_dataset, val_dataset = random_split(
            full_dataset, [num_train_data, args.num_val_data])
        train_dataloader = DataLoader(dataset=train_dataset,
                                      collate_fn=collate,
                                      batch_size=args.batch_size,
                                      shuffle=True)

        # Evaluation data does not need shuffling; a fixed order keeps the
        # batching reproducible.
        valid_dataloader = DataLoader(dataset=val_dataset,
                                      collate_fn=collate,
                                      batch_size=args.batch_size,
                                      shuffle=False)

        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
        train(model, optimizer, train_dataloader, valid_dataloader, embedding,
              args)
        logger.info('******************** Train Finished ********************')

        # Test
        if args.test_set:
            test_texts, test_labels = read_file(args.test_file_path)
            test_dataset = CustomTextDataset(test_texts, test_labels, word2idx,
                                             args)
            test_dataloader = DataLoader(dataset=test_dataset,
                                         collate_fn=collate,
                                         batch_size=args.batch_size,
                                         shuffle=False)

            # NOTE(review): "best.pt" is presumably written by train(); confirm.
            # map_location keeps the load working when the checkpoint was
            # saved on a device that differs from args.device.
            model.load_state_dict(
                torch.load(os.path.join(args.model_save_path, "best.pt"),
                           map_location=args.device))
            _, accuracy, precision, recall, f1, cm = evaluate(
                model, test_dataloader, embedding, args)
            logger.info('-' * 50)
            logger.info(
                f'|* TEST SET *| |ACC| {accuracy:>.4f} |PRECISION| {precision:>.4f} |RECALL| {recall:>.4f} |F1| {f1:>.4f}'
            )
            logger.info('-' * 50)
            logger.info('---------------- CONFUSION MATRIX ----------------')
            for row in cm:
                logger.info(row)
            logger.info('--------------------------------------------------')
            # NOTE(review): accuracy is scaled by 1/100 while the other
            # metrics are stored as-is -- presumably evaluate() reports
            # accuracy as a percentage; confirm.
            acc_list.append(accuracy / 100)
            prec_list.append(precision)
            recall_list.append(recall)
            f1_score_list.append(f1)

    if not acc_list:
        # Nothing was evaluated, so there is nothing to average or plot.
        logger.info('No test results collected; skipping averages and plot.')
        return

    avg_acc = sum(acc_list) / len(acc_list)
    avg_prec = sum(prec_list) / len(prec_list)
    avg_recall = sum(recall_list) / len(recall_list)
    avg_f1_score = sum(f1_score_list) / len(f1_score_list)
    logger.info('--------------------------------------------------')
    logger.info(
        f'|* TEST SET *| |Avg ACC| {avg_acc:>.4f} |Avg PRECISION| {avg_prec:>.4f} |Avg RECALL| {avg_recall:>.4f} |Avg F1| {avg_f1_score:>.4f}'
    )
    logger.info('--------------------------------------------------')

    # The columns hold the per-run values; the x axis is sized from the
    # results actually collected rather than a hard-coded run count.
    plot_df = pd.DataFrame({
        'x_values': range(len(acc_list)),
        'avg_acc': acc_list,
        'avg_prec': prec_list,
        'avg_recall': recall_list,
        'avg_f1_score': f1_score_list
    })
    plt.plot('x_values',
             'avg_acc',
             data=plot_df,
             marker='o',
             markerfacecolor='blue',
             markersize=12,
             color='skyblue',
             linewidth=4)
    plt.plot('x_values',
             'avg_prec',
             data=plot_df,
             marker='',
             color='olive',
             linewidth=2)
    plt.plot('x_values',
             'avg_recall',
             data=plot_df,
             marker='',
             color='olive',
             linewidth=2,
             linestyle='dashed')
    # Dotted rather than dashed so the F1 curve can be told apart from the
    # recall curve, which shares its colour and width.
    plt.plot('x_values',
             'avg_f1_score',
             data=plot_df,
             marker='',
             color='olive',
             linewidth=2,
             linestyle='dotted')
    plt.legend()

    # File name: enabled feature names joined with '-', in the fixed order
    # lexical, semantic, syntactic; 'plain' when none is enabled.
    enabled = [
        name for name, flag in (('lexical', args.lexical),
                                ('semantic', args.semantic),
                                ('syntactic', args.syntactic)) if flag
    ]
    fname = ('-'.join(enabled) or 'plain') + '.png'

    # exist_ok avoids the check-then-create race of exists() + mkdir().
    os.makedirs('./images', exist_ok=True)
    plt.savefig(os.path.join('./images', fname))
Ejemplo n.º 4
0
def load_models(config):
    """Instantiate the five ensemble members and restore their saved weights.

    The vocabulary is unpickled from ``config.data_path`` to size the
    embedding layers. Each entry of ``config.ensemble_models`` (indices 0-4:
    CNN, LSTM, LSTM_maxpool, RCNN, BERT) names a file in ``config.save_path``
    holding a pickled model object, whose ``state_dict()`` is copied into a
    freshly constructed model.

    Args:
        config: Object providing ``data_path``, ``vocabulary_name``,
            ``save_path``, ``ensemble_models``, ``source_path``,
            ``bert_folder`` and the model hyper-parameters read below
            (``embed_dim``, ``class_num``, ``kernel_num``, ``kernel_sizes``,
            ``dropout``, ``static``, ``in_channels``, ``hidden_dim``,
            ``num_layers``).

    Returns:
        Tuple ``(cnn, lstm, lstm_mxp, rcnn, bert)``.

    Note:
        Both ``pickle.load`` and ``torch.load`` unpickle arbitrary objects;
        only point ``config`` at files from a trusted source.
        NOTE(review): the models are returned without ``eval()`` being
        called here -- callers doing inference presumably need to do that.
    """

    def load_state(index):
        # Each checkpoint stores a whole model object; only its weights are
        # kept. Mapping to CPU lets checkpoints saved on a GPU load on any
        # machine -- the target models below are constructed on CPU anyway.
        checkpoint_path = os.path.join(config.save_path,
                                       config.ensemble_models[index])
        return torch.load(checkpoint_path, map_location='cpu').state_dict()

    vocabulary_path = os.path.join(config.data_path, config.vocabulary_name)
    # Context manager so the file handle is closed deterministically.
    with open(vocabulary_path, "rb") as vocabulary_file:
        vocabulary = pickle.load(vocabulary_file)
    vocab_size = len(vocabulary)

    cnn = CNN(vocab_size=vocab_size,
              embed_dim=config.embed_dim,
              class_num=config.class_num,
              kernel_num=config.kernel_num,
              kernel_sizes=config.kernel_sizes,
              dropout=config.dropout,
              static=config.static,
              in_channels=config.in_channels)
    cnn.load_state_dict(load_state(0))

    lstm = LSTM(vocab_size=vocab_size,
                embed_dim=config.embed_dim,
                output_dim=config.class_num,
                hidden_dim=config.hidden_dim,
                num_layers=config.num_layers,
                dropout=config.dropout)
    lstm.load_state_dict(load_state(1))

    lstm_mxp = LSTM_maxpool(vocab_size=vocab_size,
                            embed_dim=config.embed_dim,
                            output_dim=config.class_num,
                            hidden_dim=config.hidden_dim,
                            num_layers=config.num_layers,
                            dropout=config.dropout)
    lstm_mxp.load_state_dict(load_state(2))

    rcnn = RCNN(vocab_size=vocab_size,
                embed_dim=config.embed_dim,
                output_dim=config.class_num,
                hidden_dim=config.hidden_dim,
                num_layers=config.num_layers,
                dropout=config.dropout)
    rcnn.load_state_dict(load_state(3))

    # The number of BERT output labels is the number of schemas.
    schemas = get_schemas(config.source_path)
    bert_state = load_state(4)
    bert = BertForMultiLabelSequenceClassification.from_pretrained(
        config.bert_folder, state_dict=bert_state, num_labels=len(schemas))
    bert.load_state_dict(bert_state)

    return cnn, lstm, lstm_mxp, rcnn, bert