Example #1
def run_trainer(config):
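    # Fix all RNG seeds (Python, NumPy, PyTorch) so runs are reproducible.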
    random.seed(0)
    np.random.seed(0)
    torch.manual_seed(0)

    run_name_format = (
        "d_model={d_model}-"
        "layers_count={layers_count}-"
        "heads_count={heads_count}-"
        "pe={positional_encoding}-"
        "optimizer={optimizer}-"
        "{timestamp}"
    )

    run_name = run_name_format.format(**config, timestamp=datetime.now().strftime("%Y_%m_%d_%H_%M_%S"))

    logger = get_logger(run_name, save_log=config['save_log'])
    logger.info(f'Run name : {run_name}')
    logger.info(config)

    logger.info('Constructing dictionaries...')
    source_dictionary = IndexDictionary.load(config['data_dir'], mode='source', vocabulary_size=config['vocabulary_size'])
    target_dictionary = IndexDictionary.load(config['data_dir'], mode='target', vocabulary_size=config['vocabulary_size'])
    logger.info(f'Source dictionary vocabulary : {source_dictionary.vocabulary_size} tokens')
    logger.info(f'Target dictionary vocabulary : {target_dictionary.vocabulary_size} tokens')

    logger.info('Building model...')
    model = build_model(config, source_dictionary.vocabulary_size, target_dictionary.vocabulary_size)

    logger.info(model)
    logger.info('Encoder : {parameters_count} parameters'.format(parameters_count=sum([p.nelement() for p in model.encoder.parameters()])))
    logger.info('Decoder : {parameters_count} parameters'.format(parameters_count=sum([p.nelement() for p in model.decoder.parameters()])))
    logger.info('Total : {parameters_count} parameters'.format(parameters_count=sum([p.nelement() for p in model.parameters()])))

    logger.info('Loading datasets...')
    train_dataset = IndexedInputTargetTranslationDataset(
        data_dir=config['data_dir'],
        phase='train',
        vocabulary_size=config['vocabulary_size'],
        limit=config['dataset_limit'])

    val_dataset = IndexedInputTargetTranslationDataset(
        data_dir=config['data_dir'],
        phase='val',
        vocabulary_size=config['vocabulary_size'],
        limit=config['dataset_limit'])

    train_dataloader = DataLoader(
        train_dataset,
        batch_size=config['batch_size'],
        shuffle=True,
        collate_fn=input_target_collate_fn)

    val_dataloader = DataLoader(
        val_dataset,
        batch_size=config['batch_size'],
        collate_fn=input_target_collate_fn)

    # Use the label-smoothed loss when label_smoothing > 0, otherwise plain token-level cross entropy.
    if config['label_smoothing'] > 0.0:
        loss_function = LabelSmoothingLoss(label_smoothing=config['label_smoothing'],
                                           vocabulary_size=target_dictionary.vocabulary_size)
    else:
        loss_function = TokenCrossEntropyLoss()

    accuracy_function = AccuracyMetric()

    # Noam schedule (lr tied to d_model with warmup) or plain Adam with a fixed learning rate.
    if config['optimizer'] == 'Noam':
        optimizer = NoamOptimizer(model.parameters(), d_model=config['d_model'])
    elif config['optimizer'] == 'Adam':
        optimizer = Adam(model.parameters(), lr=config['lr'])
    else:
        raise NotImplementedError()

    logger.info('Start training...')
    trainer = EpochSeq2SeqTrainer(
        model=model,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        loss_function=loss_function,
        metric_function=accuracy_function,
        optimizer=optimizer,
        logger=logger,
        run_name=run_name,
        save_config=config['save_config'],
        save_checkpoint=config['save_checkpoint'],
        config=config
    )

    trainer.run(config['epochs'])

    return trainer
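
A minimal sketch of how run_trainer might be called; the keys below mirror the config lookups inside the function, while the concrete values, paths and flags are hypothetical and depend on the surrounding project:

config = {
    'd_model': 128, 'layers_count': 6, 'heads_count': 8,   # hypothetical model sizes
    'positional_encoding': True, 'optimizer': 'Noam', 'lr': 1e-3,
    'data_dir': 'data/example-processed',                   # hypothetical path
    'vocabulary_size': None, 'dataset_limit': None,         # hypothetical; None assumed to mean "no limit"
    'batch_size': 64, 'label_smoothing': 0.1, 'epochs': 10,
    'save_log': None, 'save_config': None, 'save_checkpoint': None,
}
trainer = run_trainer(config)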
Example #2
def prepare(data_name, range_name):
    set_dn_rn(data_name, range_name)  # set the data/range names used for on-the-fly testing
    # save_data_dir = args.save_data_dir + "-" + args.postfix
    print(f"Data name: {data_name}\tRange name: {range_name}")
    save_data_dir = args.save_data_dir.format(data_name=data_name,
                                              range_name=range_name)
    train_source = args.train_source.format(data_name=data_name,
                                            range_name=range_name)
    train_target = args.train_target.format(data_name=data_name,
                                            range_name=range_name)
    val_source = args.val_source.format(data_name=data_name,
                                        range_name=range_name)
    val_target = args.val_target.format(data_name=data_name,
                                        range_name=range_name)
    if not os.path.isfile(train_source):
        print(f"File dir: {train_source} does not exist.")
        return
    else:
        print(f"Preparing: {train_source}.")

    TranslationDataset.prepare(train_source, train_target, val_source,
                               val_target, save_data_dir)
    translation_dataset = TranslationDataset(save_data_dir, 'train')
    translation_dataset_on_the_fly = TranslationDatasetOnTheFly('train')

    # print(save_data_dir)
    # print(translation_dataset[0])
    # print(translation_dataset_on_the_fly[0])
    assert translation_dataset[0] == translation_dataset_on_the_fly[0]
    tokenized_dataset = TokenizedTranslationDataset(save_data_dir, 'train')

    if args.share_dictionary:
        source_generator = shared_tokens_generator(tokenized_dataset)
        source_dictionary = IndexDictionary(source_generator, mode='source')
        target_generator = shared_tokens_generator(tokenized_dataset)
        target_dictionary = IndexDictionary(target_generator, mode='target')
    elif not args.ind_dict:  # use the full, pre-defined vocabulary
        print("Using pre-defined dict...")
        src_dict = [
            '<PAD>', '<CLS>', '<EOS>', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
            'i', 'j', '0', '1', '!', '&', '|', 'I', 'X', 'F', 'G', 'U', 'W',
            'R'
        ]
        tgt_dict = [
            '<PAD>', '<SOS>', '<EOS>', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
            'i', 'j', '0', '1', '&', '|', '(', ')', ';', '{', '}'
        ]
        source_dictionary = IndexDictionary(None,
                                            mode='source',
                                            exist_vocab=src_dict)
        target_dictionary = IndexDictionary(None,
                                            mode='target',
                                            exist_vocab=tgt_dict)
    else:
        source_generator = source_tokens_generator(tokenized_dataset)
        source_dictionary = IndexDictionary(source_generator, mode='source')
        target_generator = target_tokens_generator(tokenized_dataset)
        target_dictionary = IndexDictionary(target_generator, mode='target')

    source_dictionary.save(save_data_dir)
    target_dictionary.save(save_data_dir)
    source_dictionary = IndexDictionary.load(save_data_dir, mode='source')
    target_dictionary = IndexDictionary.load(save_data_dir, mode='target')

    IndexedInputTargetTranslationDataset.prepare(save_data_dir,
                                                 source_dictionary,
                                                 target_dictionary)
    indexed_translation_dataset = IndexedInputTargetTranslationDataset(
        save_data_dir, 'train')
    indexed_translation_dataset_on_the_fly = IndexedInputTargetTranslationDatasetOnTheFly(
        'train', source_dictionary, target_dictionary)
    # print(indexed_translation_dataset[0])
    # print(indexed_translation_dataset_on_the_fly[0])
    assert indexed_translation_dataset[0] == indexed_translation_dataset_on_the_fly[0]
    print('Done datasets preparation.')
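
A minimal sketch of how prepare might be driven; the data/range names below are hypothetical, and args is assumed to be the module-level argparse namespace the function already refers to:

# Hypothetical name pairs; prepare() simply returns for pairs whose
# train source file does not exist.
for data_name in ['dataA', 'dataB']:
    for range_name in ['range1', 'range2']:
        prepare(data_name, range_name)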
Example #3
from argparse import ArgumentParser
import json

parser = ArgumentParser(description='Predict translation')
parser.add_argument('--source', type=str)
parser.add_argument('--config', type=str, required=True)
parser.add_argument('--checkpoint', type=str)
parser.add_argument('--num_candidates', type=int, default=1)

args = parser.parse_args()
with open(args.config) as f:
    config = json.load(f)

print('Constructing dictionaries...')
source_dictionary = IndexDictionary.load(config['data_dir'], mode='source', vocabulary_size=config['vocabulary_size'])
target_dictionary = IndexDictionary.load(config['data_dir'], mode='target', vocabulary_size=config['vocabulary_size'])

print('Building model...')
model = build_model(config, source_dictionary.vocabulary_size, target_dictionary.vocabulary_size)

predictor = Predictor(
    preprocess=IndexedInputTargetTranslationDataset.preprocess(source_dictionary),
    postprocess=lambda x: ' '.join([token for token in target_dictionary.tokenify_indexes(x) if token != '<EndSent>']),
    model=model,
    checkpoint_filepath=args.checkpoint
)

for index, candidate in enumerate(predictor.predict_one(args.source, num_candidates=args.num_candidates)):
    print(f'Candidate {index} : {candidate}')
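
Assuming the script above is saved as predict.py (the filename, paths and the example sentence are hypothetical), it could be invoked along these lines:

# python predict.py --config experiments/example_config.json \
#                   --checkpoint experiments/example_checkpoint.pth \
#                   --source "example source sentence" --num_candidates 3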
Example #4
if args.share_dictionary:
    source_generator = shared_tokens_generator(tokenized_dataset)
    source_dictionary = IndexDictionary(source_generator, mode='source')
    target_generator = shared_tokens_generator(tokenized_dataset)
    target_dictionary = IndexDictionary(target_generator, mode='target')

    source_dictionary.save(args.save_data_dir)
    target_dictionary.save(args.save_data_dir)
else:
    source_generator = source_tokens_generator(tokenized_dataset)
    source_dictionary = IndexDictionary(source_generator, mode='source')
    target_generator = target_tokens_generator(tokenized_dataset)
    target_dictionary = IndexDictionary(target_generator, mode='target')

    source_dictionary.save(args.save_data_dir)
    target_dictionary.save(args.save_data_dir)

source_dictionary = IndexDictionary.load(args.save_data_dir, mode='source')
target_dictionary = IndexDictionary.load(args.save_data_dir, mode='target')

IndexedInputTargetTranslationDataset.prepare(args.save_data_dir,
                                             source_dictionary,
                                             target_dictionary)
indexed_translation_dataset = IndexedInputTargetTranslationDataset(
    args.save_data_dir, 'train')
indexed_translation_dataset_on_the_fly = IndexedInputTargetTranslationDatasetOnTheFly(
    'train', source_dictionary, target_dictionary)
assert indexed_translation_dataset[0] == indexed_translation_dataset_on_the_fly[0]

print('Done datasets preparation.')
Example #5
def run_trainer(config):
    random.seed(0)
    np.random.seed(0)
    torch.manual_seed(0)

    run_name_format = (
        f"data={data_name}-"
        f"range={range_name}-"
        "d_model={d_model}-"
        "layers_count={nlayers}-"
        "heads_count={nhead}-"
        "FC_size={nhid}-"
        "lr={lr}-"
        "{timestamp}"
    )

    run_name = run_name_format.format(**config, timestamp=datetime.now().strftime("%Y_%m_%d_%H_%M_%S"))

    logger = get_logger(run_name, save_log=config['save_log'])
    logger.info(f'Run name : {run_name}')
    logger.info(config)

    data_dir = config['data_dir'] + "-" + data_name + "-" + range_name
    logger.info(f'Constructing dictionaries from {data_dir}...')
    source_dictionary = IndexDictionary.load(data_dir, mode='source')
    target_dictionary = IndexDictionary.load(data_dir, mode='target')
    logger.info(f'Source dictionary vocabulary : {source_dictionary.vocabulary_size} tokens')
    logger.info(f'Target dictionary vocabulary : {target_dictionary.vocabulary_size} tokens')

    logger.info('Building model...')
    model = TransformerModel(source_dictionary.vocabulary_size, target_dictionary.vocabulary_size,
                             d_model=config['d_model'],
                             nhead=config['nhead'],
                             nhid=config['nhid'],
                             nlayers=config['nlayers'])
    logger.info(model)
    logger.info('Encoder : {parameters_count} parameters'.format(parameters_count=sum([p.nelement() for p in model.transformer_encoder.parameters()])))
    logger.info('Decoder : {parameters_count} parameters'.format(parameters_count=sum([p.nelement() for p in model.transformer_decoder.parameters()])))
    logger.info('Total : {parameters_count} parameters'.format(parameters_count=sum([p.nelement() for p in model.parameters()])))

    logger.info('Loading datasets...')
    train_dataset = IndexedInputTargetTranslationDataset(
        data_dir=data_dir,
        phase='train')

    val_dataset = IndexedInputTargetTranslationDataset(
        data_dir=data_dir,
        phase='val')

    train_dataloader = DataLoader(
        train_dataset,
        batch_size=config['batch_size'],
        shuffle=True,
        collate_fn=input_target_collate_fn,
        num_workers=5)

    val_dataloader = DataLoader(
        val_dataset,
        batch_size=config['batch_size'],
        collate_fn=input_target_collate_fn,
        num_workers=5)

    loss_function = TokenCrossEntropyLoss()
    accuracy_function = AccuracyMetric()
    optimizer = Adam(model.parameters(), lr=config['lr'])

    logger.info('Start training...')
    trainer = EpochSeq2SeqTrainer(
        model=model,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        loss_function=loss_function,
        metric_function=accuracy_function,
        optimizer=optimizer,
        logger=logger,
        run_name=run_name,
        save_config=config['save_config'],
        save_checkpoint=config['save_checkpoint'],
        config=config,
        iter_num=args.iter_num
    )

    trainer.run(config['epochs'])

    return trainer
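
As above, a minimal sketch of a config this variant expects; the keys mirror the lookups inside the function, the values are hypothetical, and data_name, range_name and args are assumed to be module-level globals:

config = {
    'd_model': 128, 'nhead': 8, 'nhid': 512, 'nlayers': 6,   # hypothetical model sizes
    'lr': 1e-4, 'batch_size': 64, 'epochs': 10,
    'data_dir': 'data/processed',   # hypothetical; "-{data_name}-{range_name}" is appended inside run_trainer
    'save_log': None, 'save_config': None, 'save_checkpoint': None,
}
trainer = run_trainer(config)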
Example #6
def predict(dn, rn):
    dir_name_format = "../data/{dn}-{rn}-raw"
    dir_name = dir_name_format.format(dn=dn, rn=rn)
    input_path = os.path.join(dir_name, "src-test.txt")
    if not os.path.isfile(input_path):
        print(f"File: {input_path} not exist.")
        return

    output_filename = f"prediction-{dn}-{rn}.txt"
    output_path = os.path.join(outputDir, output_filename)
    if os.path.isfile(output_path):
        print(f"File {output_path} already exists.")
        return

    # map the source sequence to token indexes
    preprocess = IndexedInputTargetTranslationDataset.preprocess(source_dictionary)
    # map output indexes back to a token string, dropping special tokens
    postprocess = lambda x: ''.join(
        [token for token in target_dictionary.tokenize_indexes(x) if token != END_TOKEN and token != START_TOKEN and token != PAD_TOKEN])
    device = torch.device(f'cuda:{args.device}' if torch.cuda.is_available() and not args.no_cuda else 'cpu')

    print('Building model...')
    model = TransformerModel(source_dictionary.vocabulary_size, target_dictionary.vocabulary_size,
                             config['d_model'],
                             config['nhead'],
                             config['nhid'],
                             config['nlayers'])
    model.eval()
    checkpoint_filepath = checkpoint_path
    checkpoint = torch.load(checkpoint_filepath, map_location='cpu')
    model.load_state_dict(checkpoint)
    translator = Translator(
        model=model,
        beam_size=args.beam_size,
        max_seq_len=args.max_seq_len,
        trg_bos_idx=target_dictionary.token_to_index(START_TOKEN),
        trg_eos_idx=target_dictionary.token_to_index(END_TOKEN)
    ).to(device)

    from utils.pipe import PAD_INDEX
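    # Right-pad every source sequence in the batch with PAD_INDEX up to the longest length.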
    def pad_src(batch):
        sources_lengths = [len(sources) for sources in batch]
        sources_max_length = max(sources_lengths)
        sources_padded = [sources + [PAD_INDEX] * (sources_max_length - len(sources)) for sources in batch]
        sources_tensor = torch.tensor(sources_padded)
        return sources_tensor
    def process(seq):
        seq = seq.strip()
        def is_proof(name):
            return name.count("balance") > 0 or name.count("one") > 0
        if is_proof(data_name) and not is_proof(dn):
            seq += ",$,1"
            global is_proof_process
            if is_proof_process:
                print("processing")
                is_proof_process = False
        return seq

    batch_size = args.bs
    print(f"Output to {output_path}:")
    with open(output_path, 'w', encoding='utf-8') as outFile:
        with open(input_path, 'r', encoding='utf-8') as inFile:
            seqs = []
            for seq in tqdm(inFile):
                seq = process(seq)
                src_seq = preprocess(seq)
                seqs.append(src_seq)
                if len(seqs) >= batch_size:
                    pred_seq = translator.translate_sentence(pad_src(seqs).to(device))
                    pred_line = [postprocess(pred) for pred in pred_seq]
                    # print(pred_line)
                    outFile.writelines([p.strip() + '\n' for p in pred_line])
                    seqs.clear()
                # endif
            # endfor
            if seqs:    # last batch
                pred_seq = translator.translate_sentence(pad_src(seqs).to(device))
                pred_line = [postprocess(pred).replace(START_TOKEN, '').replace(END_TOKEN, '') for pred in pred_seq]
                # print(pred_line)
                outFile.writelines([p.strip() + '\n' for p in pred_line])
                seqs.clear()
        # endwith
    # endwith
    print(f'[Info] Finished {input_path}.')
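
A minimal sketch of driving predict over several data/range pairs; the names are hypothetical, and the module-level globals the function relies on (args, config, source_dictionary, target_dictionary, outputDir, checkpoint_path, data_name and the special-token constants) are assumed to be set up earlier in the script:

for dn in ['dataA', 'dataB']:            # hypothetical data names
    for rn in ['range1', 'range2']:      # hypothetical range names
        predict(dn, rn)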