示例#1
0
def main(config):
    """Run a full training session described by ``config``.

    A training data loader is obtained from ``get_data_loader(config)`` and
    handed to a ``Solver`` created with ``is_train=True``. The configuration
    and the solver's ``total_data_size`` are printed before the solver is
    built and trained.

    Args:
        config: Configuration object; passed unchanged to the data-loader
            factory and to ``Solver``.
    """
    loader = get_data_loader(config)
    trainer = Solver(config, train_loader=loader, is_train=True)

    # Report what we are about to train on.
    print(config)
    print(f'\nTotal data size: {trainer.total_data_size}\n')

    # The solver is built first, then trained.
    trainer.build()
    trainer.train()
示例#2
0
def train(config, device, RS='Supervised'):
    """Build the data pipeline and MixNet model from ``config`` and train it.

    Steps, in order: create the tokenizer; build the feature index over the
    labeled file plus the optional extra-train and validation files; run the
    training preprocessing; wrap the resulting data directories in datasets
    and data loaders; optionally load an NER model; construct the ``MixNet``
    model; and hand everything to a ``Solver`` that is built and trained.

    Args:
        config: Configuration object. This function reads ``temp_dir``,
            ``jieba_dict_file``, ``remove_stopwords``, ``stopwords_file``,
            ``ivr``, ``labeled_file``, ``extra_train_file``, ``valid_file``,
            ``feature_config_dict``, ``read_workers``, ``model_config_dict``
            and ``output_config_dict``; the object is also passed whole to
            the helper classes.
        device: Forwarded to ``NERModel`` (when NER is enabled) and to
            ``Solver``.
        RS: Forwarded unchanged to ``Solver`` as its last positional
            argument; its meaning is defined there.
    """
    # Init tokenizer.
    tokenizer = Tokenizer(config.temp_dir, config.jieba_dict_file,
                          config.remove_stopwords, config.stopwords_file,
                          config.ivr)

    # Init feature index from the labeled file and any optional files that
    # are configured.
    feature_index = FeatureIndex(config, tokenizer=tokenizer)
    file_list = [config.labeled_file]
    if config.extra_train_file is not None:
        file_list.append(config.extra_train_file)
    if config.valid_file is not None:
        file_list.append(config.valid_file)
    feature_index.build_index(file_list)

    # Preprocess data. Only the two data directories are used below; the
    # remaining two returned values are unpacked (to keep the arity check)
    # but intentionally ignored.
    pre_process = PreProcess(config)
    (train_data_dir, valid_data_dir,
     _final_train_file, _final_valid_file) = pre_process.train_preprocess()

    # Get PyTorch datasets.
    train_dataset = MixnetDataset(config, train_data_dir, feature_index,
                                  tokenizer)
    # NOTE(review): the extra positional ``True`` presumably marks the
    # validation/evaluation dataset -- confirm against MixnetDataset.
    valid_dataset = MixnetDataset(config, valid_data_dir, feature_index,
                                  tokenizer, True)

    # NER is needed as soon as one feature is a text feature segmented by
    # character with its "ner" option set to a truthy value.
    need_ner = any(
        "text" in feature_config.get("type", "")
        and feature_config.get("seg_type", "word") == "char"
        and feature_config.get("ner", False)
        for feature_config in config.feature_config_dict.values())
    if need_ner:
        logger.info("Enable NER, loading NER model...")
        # Use predict mode since we cannot train it without tag information.
        ner_model = NERModel(device, "predict")
    else:
        logger.info("Disable NER.")
        ner_model = None

    # Get PyTorch data loaders (batch size 1, no shuffling, for both splits).
    train_data_loader = DataLoader(train_dataset,
                                   batch_size=1,
                                   shuffle=False,
                                   num_workers=config.read_workers)
    valid_data_loader = DataLoader(valid_dataset,
                                   batch_size=1,
                                   shuffle=False,
                                   num_workers=config.read_workers)

    # Init model.
    model = MixNet(config.model_config_dict,
                   config.output_config_dict,
                   feature_index.feature_info_dict,
                   feature_index.label_info_dict,
                   ner_model=ner_model)

    # Train model.
    solver = Solver(config, train_data_loader, valid_data_loader,
                    feature_index, model, device, RS)
    solver.build()
    solver.train()
示例#3
0
from configs import get_config
from solver import Solver
from data_loader import get_loader


def _train_with_initial_evaluation():
    """Train a solver, evaluating once before training starts.

    Loads the 'train' and 'test' configurations, prints both along with the
    training ``split_index``, creates one data loader per configuration and
    passes them to ``Solver``. After building, ``evaluate(-1)`` is called
    before ``train()``.
    """
    train_cfg = get_config(mode='train')
    test_cfg = get_config(mode='test')
    both_cfgs = (train_cfg, test_cfg)

    for cfg in both_cfgs:
        print(cfg)
    print('split_index:', train_cfg.split_index)

    # One loader per configuration, train first and test second.
    loaders = [get_loader(cfg.mode, cfg.split_index) for cfg in both_cfgs]
    runner = Solver(train_cfg, *loaders)

    runner.build()
    # Evaluates the summaries generated using the initial random weights of
    # the network.
    runner.evaluate(-1)
    runner.train()


if __name__ == '__main__':
    _train_with_initial_evaluation()
示例#4
0
from solver import Solver
from data_loader import get_loader
from configs import get_config

def _run_evaluation():
    """Build a solver in non-training mode and run its evaluation.

    Reads the configuration via ``get_config()``, creates a data loader with
    ``is_train=False`` from the configured batch size, vocabulary size and
    data directory, then builds the ``Solver`` with ``is_train=False`` and
    calls ``eval()``.
    """
    cfg = get_config()

    loader_options = {
        'batch_size': cfg.batch_size,
        'max_size': cfg.vocab_size,
        'is_train': False,
        'data_dir': cfg.data_dir,
    }
    eval_loader = get_loader(**loader_options)

    evaluator = Solver(cfg, eval_loader)
    evaluator.build(is_train=False)
    evaluator.eval()


if __name__ == '__main__':
    _run_evaluation()
示例#5
0
from solver import Solver
from data_loader import get_loader
from configs import get_config

import os, sys
import mlflow

def _run_training():
    """Build a solver in training mode and train it.

    Reads the configuration via ``get_config()`` and prints it, creates a
    data loader with ``is_train=True`` from the configured batch size,
    vocabulary size and data directory, then builds the ``Solver`` with
    ``is_train=True`` and calls ``train()``.
    """
    cfg = get_config()
    print(cfg)

    loader_options = {
        'batch_size': cfg.batch_size,
        'max_size': cfg.vocab_size,
        'is_train': True,
        'data_dir': cfg.data_dir,
    }
    train_loader = get_loader(**loader_options)

    trainer = Solver(cfg, train_loader)
    trainer.build(is_train=True)
    trainer.train()


if __name__ == '__main__':
    _run_training()