Example #1
def load_data(config):
    """Build train/dev/test dataloaders for the dataset selected by config.data_sign."""
    print("-*-"*10)
    print("current data_sign: {}".format(config.data_sign))

    if config.data_sign == "conll03":
        data_processor = Conll03Processor()
    elif config.data_sign == "zh_msra":
        data_processor = MSRAProcessor()
    elif config.data_sign == "zh_onto":
        data_processor = Onto4ZhProcessor()
    elif config.data_sign == "en_onto":
        data_processor = Onto5EngProcessor()
    elif config.data_sign == "genia":
        data_processor = GeniaProcessor()
    elif config.data_sign == "ace2004":
        data_processor = ACE2004Processor()
    elif config.data_sign == "ace2005":
        data_processor = ACE2005Processor()
    elif config.data_sign == "resume":
        data_processor = ResumeZhProcessor()
    else:
        raise ValueError("Please Notice that your data_sign DO NOT exits !!!!!")

    label_list = data_processor.get_labels()
    tokenizer = BertTokenizer4Tagger.from_pretrained(config.bert_model, do_lower_case=True)

    dataset_loaders = MRCNERDataLoader(config, data_processor, label_list, tokenizer, mode="train", allow_impossible=True)
    train_dataloader = dataset_loaders.get_dataloader(data_sign="train") 
    dev_dataloader = dataset_loaders.get_dataloader(data_sign="dev")
    test_dataloader = dataset_loaders.get_dataloader(data_sign="test")
    num_train_steps = dataset_loaders.get_num_train_epochs()

    return train_dataloader, dev_dataloader, test_dataloader, num_train_steps, label_list 
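
A minimal call sketch for this variant. The Namespace-based config below is a hypothetical stand-in: only data_sign and bert_model are read directly here, and the real MRCNERDataLoader will likely expect further fields (data paths, batch sizes, sequence length) not shown in this example.

from argparse import Namespace

# Hypothetical config: data_sign picks one of the processor branches above,
# bert_model points at a pretrained BERT checkpoint for BertTokenizer4Tagger.
config = Namespace(data_sign="conll03", bert_model="/path/to/bert-base-cased")

train_loader, dev_loader, test_loader, num_train_steps, label_list = load_data(config)
print("number of labels:", len(label_list))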

Example #2
def load_data(config, logger):
    """Build train/dev/test dataloaders and log the training setup."""
    logger.info("-*-"*10)
    logger.info(f"current data_sign: {config.data_sign}")

    if config.data_sign == "conll03":
        data_processor = Conll03Processor()
    elif config.data_sign == "zh_msra":
        data_processor = MSRAProcessor()
    elif config.data_sign == "zh_onto":
        data_processor = Onto4ZhProcessor()
    elif config.data_sign == "en_onto":
        data_processor = Onto5EngProcessor()
    elif config.data_sign == "genia":
        data_processor = GeniaProcessor()
    elif config.data_sign == "ace2004":
        data_processor = ACE2004Processor()
    elif config.data_sign == "ace2005":
        data_processor = ACE2005Processor()
    elif config.data_sign == "resume":
        data_processor = ResumeZhProcessor()
    elif config.data_sign == "en_wnut_20_wlp":
        data_processor = WlpWnut20Processor()
    else:
        raise ValueError("Please Notice that your data_sign DO NOT exits !!!!!")


    label_list = data_processor.get_labels()
    tokenizer = BertTokenizer4Tagger.from_pretrained(config.bert_model)

    dataset_loaders = MRCNERDataLoader(config, data_processor, label_list,
                                       tokenizer, mode="train", allow_impossible=True, ) # entity_scheme=config.entity_scheme)
    if config.debug:
        logger.info("%="*20)
        logger.info("="*10 + " DEBUG MODE " + "="*10)
        train_dataloader = dataset_loaders.get_dataloader(data_sign="dev", num_data_processor=config.num_data_processor, logger=logger)
    else:
        train_dataloader = dataset_loaders.get_dataloader(data_sign="train", num_data_processor=config.num_data_processor, logger=logger)
    dev_dataloader = dataset_loaders.get_dataloader(data_sign="dev", num_data_processor=config.num_data_processor, logger=logger)
    test_dataloader = dataset_loaders.get_dataloader(data_sign="test", num_data_processor=config.num_data_processor, logger=logger)
    train_instances = dataset_loaders.get_train_instance()
    num_train_steps = len(train_dataloader) // config.gradient_accumulation_steps * config.num_train_epochs
    per_gpu_train_batch_size = config.train_batch_size // config.n_gpu

    logger.info("****** Running Training ******")
    logger.info(f"Number of Training Data: {train_instances}")
    logger.info(f"Train Epoch {config.num_train_epochs}; Total Train Steps: {num_train_steps}; Warmup Train Steps: {config.warmup_steps}")
    logger.info(f"Per GPU Train Batch Size: {per_gpu_train_batch_size}")

    return train_dataloader, dev_dataloader, test_dataloader, num_train_steps, label_list 
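
For reference, the total-step computation in this variant is plain integer arithmetic: one optimizer step per gradient_accumulation_steps batches, repeated for num_train_epochs epochs. The numbers in the sketch below are illustrative, not taken from the project.

# Sketch of the step arithmetic used above, with illustrative numbers.
num_batches_per_epoch = 1250          # stands in for len(train_dataloader)
gradient_accumulation_steps = 2       # one optimizer step every 2 batches
num_train_epochs = 4

num_train_steps = num_batches_per_epoch // gradient_accumulation_steps * num_train_epochs
print(num_train_steps)  # 1250 // 2 * 4 = 2500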

Example #3
def load_data(config, logger):
    """Build only the test dataloader for evaluation/inference."""
    logger.info("-*-"*10)
    logger.info(f"current data_sign: {config.data_sign}")

    if config.data_sign == "conll03":
        data_processor = Conll03Processor()
    elif config.data_sign == "zh_msra":
        data_processor = MSRAProcessor()
    elif config.data_sign == "zh_onto":
        data_processor = Onto4ZhProcessor()
    elif config.data_sign == "en_onto":
        data_processor = Onto5EngProcessor()
    elif config.data_sign == "genia":
        data_processor = GeniaProcessor()
    elif config.data_sign == "ace2004":
        data_processor = ACE2004Processor()
    elif config.data_sign == "ace2005":
        data_processor = ACE2005Processor()
    elif config.data_sign == "resume":
        data_processor = ResumeZhProcessor()
    else:
        raise ValueError("Please Notice that your data_sign DO NOT exits !!!!!")


    label_list = data_processor.get_labels()
    tokenizer = BertTokenizer4Tagger.from_pretrained(config.bert_model, do_lower_case=config.do_lower_case)

    dataset_loaders = MRCNERDataLoader(config, data_processor, label_list, tokenizer, mode="test", allow_impossible=True)
    test_dataloader = dataset_loaders.get_dataloader(data_sign="test", num_data_processor=config.num_data_processor, logger=logger)

    return test_dataloader, label_list
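
A hedged usage sketch for this evaluation-only variant. The logger setup and the Namespace config are assumptions; do_lower_case should match the casing of the chosen BERT checkpoint, and additional fields may be required by MRCNERDataLoader.

import logging
from argparse import Namespace

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("mrc-ner-eval")

# Hypothetical config covering only the fields this variant reads.
config = Namespace(
    data_sign="zh_msra",
    bert_model="/path/to/bert-base-chinese",
    do_lower_case=False,
    num_data_processor=2,
)

test_loader, label_list = load_data(config, logger)
logger.info(f"test batches: {len(test_loader)}, labels: {label_list}")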