コード例 #1
0
ファイル: run.py プロジェクト: yihenglu/Text_Matching
def main(config):
    """Optionally train, then evaluate, a text classifier on SST-2.

    Steps: create the model/log directories, seed the RNGs, build the
    torchtext fields and SST-2 iterators via ``load_sst2``, construct the
    model named by ``config.model_name``, train it when ``config.do_train``
    is set, then reload ``model1.pt`` from ``config.model_dir`` and print
    test loss, accuracy and macro/weighted F1.

    Args:
        config: run configuration. Attributes read here: model_dir, log_dir,
            model_name, seed, sequence_length, data_path, batch_size,
            glove_word_file, cache_path, glove_word_dim, output_dim,
            hidden_size, num_layers, bidirectional, dropout, do_train,
            epoch_num, print_step.

    Raises:
        ValueError: if ``config.model_name`` is not ``'TextRNN'`` (the only
            model this script builds).
    """
    # Fail fast on an unknown model before the (slow) data loading;
    # otherwise ``model`` would be left unbound further down.
    if config.model_name != 'TextRNN':
        raise ValueError(
            "Unsupported model name: {!r} (expected 'TextRNN')".format(
                config.model_name))

    # exist_ok makes creation idempotent and avoids a check-then-create race.
    os.makedirs(config.model_dir, exist_ok=True)
    os.makedirs(config.log_dir, exist_ok=True)

    print("\t \t \t the model name is {}".format(config.model_name))
    device, n_gpu = get_device()

    # Seed every RNG in use so runs are reproducible.
    torch.manual_seed(config.seed)
    np.random.seed(config.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(config.seed)
        # Make cuDNN choose deterministic algorithms so repeated runs match.
        torch.backends.cudnn.deterministic = True

    # --- Data preparation ---
    text_field = data.Field(tokenize='spacy',
                            lower=True,
                            include_lengths=True,
                            fix_length=config.sequence_length)
    label_field = data.LabelField(dtype=torch.long)

    train_iterator, dev_iterator, test_iterator = load_sst2(
        config.data_path, text_field, label_field, config.batch_size, device,
        config.glove_word_file, config.cache_path)

    # --- Pretrained word vectors ---
    # Assumes load_sst2 built text_field's vocab with the vectors from
    # config.glove_word_file — confirm against load_sst2.
    pretrained_embeddings = text_field.vocab.vectors

    # os.path.join works whether or not model_dir ends with a separator;
    # plain string concatenation would yield e.g. "modelsmodel1.pt".
    model_file = os.path.join(config.model_dir, 'model1.pt')

    # --- Model ---
    from TextRNN import TextRNN
    model = TextRNN.TextRNN(config.glove_word_dim, config.output_dim,
                            config.hidden_size, config.num_layers,
                            config.bidirectional, config.dropout,
                            pretrained_embeddings)
    # Move to the target device before building the optimizer (as the
    # torch.optim docs recommend); load_sst2 was given the same `device`.
    model = model.to(device)

    optimizer = optim.Adam(model.parameters())
    criterion = nn.CrossEntropyLoss().to(device)

    if config.do_train:
        train(config.epoch_num, model, train_iterator, dev_iterator, optimizer,
              criterion, ['0', '1'], model_file, config.log_dir,
              config.print_step, 'word')

    # Reload the checkpoint at model_file (presumably written by train() —
    # verify) before testing. map_location lets a checkpoint saved on a GPU
    # be restored on whichever device is active now.
    model.load_state_dict(torch.load(model_file, map_location=device))

    test_loss, test_acc, test_report = evaluate(model, test_iterator,
                                                criterion, ['0', '1'], 'word')
    print("-------------- Test -------------")
    print("\t Loss: {} | Acc: {} | Macro avg F1: {} | Weighted avg F1: {}".
          format(test_loss, test_acc, test_report['macro avg']['f1-score'],
                 test_report['weighted avg']['f1-score']))
コード例 #2
0
ファイル: run_SST.py プロジェクト: selene009/NLP-Pytorch
def main(config):
    """Train a text classifier on SST-2 and report dev/test metrics.

    Seeds the RNGs, builds the torchtext fields and SST-2 iterators via
    ``load_sst2``, constructs the model named by ``config.model_name``, and
    trains for ``config.epoch_num`` epochs. After each epoch the weights are
    checkpointed to ``'tut2-model.pt'`` whenever the dev loss improves. The
    best checkpoint from this run is then reloaded and evaluated on the test
    set.

    Args:
        config: run configuration. Attributes read here: model_name, seed,
            sequence_length, data_path, batch_size, glove_word_file,
            glove_word_dim, output_dim, dropout, epoch_num; plus filter_num
            and filter_sizes (a whitespace-separated string of ints) for
            TextCNN, and hidden_size, num_layers, bidirectional for the
            recurrent models.

    Raises:
        ValueError: if ``config.model_name`` is not one of TextCNN, TextRNN,
            LSTMATT or TextRCNN.
    """
    # Fail fast on an unknown model before the (slow) data loading;
    # otherwise ``model`` would be left unbound further down.
    supported_models = ("TextCNN", "TextRNN", "LSTMATT", "TextRCNN")
    if config.model_name not in supported_models:
        raise ValueError("Unsupported model name: {!r} (expected one of {})".
                         format(config.model_name, ", ".join(supported_models)))

    print("\t \t \t the model name is {}".format(config.model_name))
    device, n_gpu = get_device()

    # Seed every RNG in use so runs are reproducible.
    torch.manual_seed(config.seed)
    np.random.seed(config.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(config.seed)
        # Make cuDNN choose deterministic algorithms so repeated runs match.
        torch.backends.cudnn.deterministic = True

    # --- SST-2 data preparation ---
    text_field = data.Field(tokenize='spacy',
                            lower=True,
                            include_lengths=True,
                            fix_length=config.sequence_length)
    label_field = data.LabelField(dtype=torch.long)

    train_iterator, dev_iterator, test_iterator = load_sst2(
        config.data_path, text_field, label_field, config.batch_size, device,
        config.glove_word_file)

    # --- Pretrained word vectors ---
    # Assumes load_sst2 built text_field's vocab with the vectors from
    # config.glove_word_file — confirm against load_sst2.
    pretrained_embeddings = text_field.vocab.vectors

    # --- Model ---
    if config.model_name == "TextCNN":
        filter_sizes = [int(val) for val in config.filter_sizes.split()]
        model = TextCNN.TextCNN(config.glove_word_dim, config.filter_num,
                                filter_sizes, config.output_dim,
                                config.dropout, pretrained_embeddings)
    elif config.model_name == "TextRNN":
        model = TextRNN.TextRNN(config.glove_word_dim, config.output_dim,
                                config.hidden_size, config.num_layers,
                                config.bidirectional, config.dropout,
                                pretrained_embeddings)
    elif config.model_name == "LSTMATT":
        model = LSTMATT.LSTMATT(config.glove_word_dim, config.output_dim,
                                config.hidden_size, config.num_layers,
                                config.bidirectional, config.dropout,
                                pretrained_embeddings)
    else:  # "TextRCNN" — guaranteed by the validation above
        model = TextRCNN.TextRCNN(config.glove_word_dim, config.output_dim,
                                  config.hidden_size, config.num_layers,
                                  config.bidirectional, config.dropout,
                                  pretrained_embeddings)

    # Move to the target device before building the optimizer, as the
    # torch.optim docs recommend.
    model = model.to(device)
    optimizer = optim.Adam(model.parameters())
    criterion = nn.CrossEntropyLoss().to(device)

    checkpoint_path = 'tut2-model.pt'
    best_dev_loss = float('inf')
    # Tracks whether *this* run wrote a checkpoint. If none is written (no
    # epochs, or a NaN dev loss that never compares below inf), reloading
    # would pick up a stale file from an earlier run or fail outright.
    saved_checkpoint = False
    for epoch in range(config.epoch_num):
        start_time = time.time()

        train_loss, train_acc, train_report = train(model, train_iterator,
                                                    optimizer, criterion,
                                                    config.output_dim)
        dev_loss, dev_acc, dev_report = evaluate(model, dev_iterator,
                                                 criterion, config.output_dim)

        end_time = time.time()

        epoch_mins, epoch_secs = epoch_time(start_time, end_time)

        # Keep only the weights with the lowest dev loss seen so far.
        if dev_loss < best_dev_loss:
            best_dev_loss = dev_loss
            torch.save(model.state_dict(), checkpoint_path)
            saved_checkpoint = True

        print(
            f'---------------- Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s ----------'
        )
        print("-------------- Train -------------")
        print(f'\t \t Loss: {train_loss:.3f} |  Acc: {train_acc*100: .2f} %')
        print('{}'.format(train_report))
        print("-------------- Dev -------------")
        print(f'\t \t Loss: {dev_loss: .3f} | Acc: {dev_acc*100: .2f} %')
        print('{}'.format(dev_report))

    if saved_checkpoint:
        # map_location lets a checkpoint saved on a GPU be restored on the
        # currently active device.
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    else:
        print("No checkpoint was saved during this run; "
              "evaluating the current model weights instead.")

    test_loss, test_acc, test_report = evaluate(model, test_iterator,
                                                criterion, config.output_dim)
    print("-------------- Test -------------")
    print(f'\t \t Loss: {test_loss: .3f} | Acc: {test_acc*100: .2f} %')
    print('{}'.format(test_report))