Code example #1
File: launch.py — project: sb-nmt-team/sb-nmt
def main():
    """Entry point: parse CLI args, set up an experiment folder and logging,
    load the dataset, build the Seq2Seq model, and run training.

    Command-line options (both optional, JSON files):
      --params           model hyper-parameters
      --training_params  trainer hyper-parameters
    """
    import copy  # local import: file-level import block is outside this view

    parser = argparse.ArgumentParser()
    parser.add_argument('--params')
    parser.add_argument('--training_params')
    args = parser.parse_args()

    # todo logdir
    # todo load from logdir if possible, load params
    training_params = Trainer.get_default_hparams()
    if args.training_params:
        with open(args.training_params, encoding='utf-8') as fin:
            training_params = training_params.parse_dict(json.load(fin))

    # Deep-copy the shared base config: handler filenames are patched below,
    # and mutating LOGGING_BASE in place would leak the experiment-specific
    # paths into every other user of the module-level constant.
    logger_config = copy.deepcopy(LOGGING_BASE)

    model_name = os.path.join(training_params.prefix,
                              training_params.model_name)
    model_folder = os.path.join(TRAINED_MODELS_FOLDER, model_name)

    latest_folder = (find_latest_experiment(model_folder)
                     if os.path.exists(model_folder) else None)
    # force_override discards any previous experiment folder.
    latest_folder = None if training_params.force_override else latest_folder
    new_folder = create_new_experiment(model_folder, latest_folder)

    logger_config['handlers']['debug']['filename'] = os.path.join(
        new_folder, 'debug_logs')
    logger_config['handlers']['stdout']['filename'] = os.path.join(
        new_folder, 'stdout_logs')
    logging.config.dictConfig(logger_config)

    # Both CLI options default to None and copy2(None, ...) raises TypeError,
    # so only archive the config files that were actually supplied.
    if args.training_params:
        copy2(args.training_params, new_folder)
    if args.params:
        copy2(args.params, new_folder)

    logger.info("Using python binary at {}".format(sys.executable))

    os.environ['CUDA_VISIBLE_DEVICES'] = str(
        training_params.cuda_visible_devices)
    if torch.cuda.is_available() and training_params.use_cuda:
        logger.info('GPU found, running on device {}'.format(
            torch.cuda.current_device()))
    elif training_params.use_cuda:
        logger.warning(
            'GPU not found, running on CPU. Overriding use_cuda to False.')
        training_params.set('use_cuda', False)
    elif torch.cuda.is_available():
        # use_cuda=False although a GPU is present.
        logger.debug('GPU found, but use_cuda=False, consider using GPU.')
    else:
        # Fixed: the old single else-branch claimed "GPU found" even with no GPU.
        logger.debug('GPU not found and use_cuda=False, running on CPU.')

    log_experiment_info(model_name, new_folder, latest_folder)

    hps = get_hps(new_folder, args)

    # Read the full dataset once to learn the training-set size, then re-read
    # a 1/hps.fraction slice of it for the actual run.
    dataset_path = os.path.join(DATASET_DIR, hps.dataset)
    full_dataset, src, tgt = read_problem(dataset_path, n_sents=None)
    dataset, src, tgt = read_problem(dataset_path,
                                     n_sents=len(full_dataset["train"][0]) //
                                     hps.fraction)

    training_params.set('logdir', new_folder)
    log_parameters_info(hps, training_params)
    log_dataset_info(hps, full_dataset, dataset)

    batch_sampler = BatchSampler(dataset,
                                 src_lang=src,
                                 tgt_lang=tgt,
                                 batch_size=training_params.batch_size)

    # Optionally build a translation-memory search engine for the model.
    searchengine = None
    if hps.tm_init:
        logger.info("Using translation memory.")
        if hps.tm_overfitted:
            logger.info("Using overfitted search engine.")
            searchengine = OverfittedSearchEngine()
        else:
            logger.info("Using normal search engine.")
            searchengine = SearchEngine()
        searchengine.load(hps.tm_bin_path)
        searchengine.set_dictionary(full_dataset)
        if hps.tm_50_50:
            searchengine.remove_train_set(dataset["train"][0])

    writer = SummaryWriter(log_dir=training_params.logdir)

    model = s2s.Seq2Seq(src,
                        tgt,
                        hps,
                        training_params,
                        writer=writer,
                        searchengine=searchengine)

    # Archive the model description alongside the experiment artifacts.
    with open(os.path.join(new_folder, "model.meta"), "w") as fout:
        fout.write(repr(model))
    translate_to_all_loggers(repr(model))

    trainer = Trainer(model, batch_sampler, hps, training_params, writer,
                      searchengine)
    trainer.train()
    writer.export_scalars_to_json("./all_scalars.json")
    writer.close()
Code example #2
File: test.py — project: sb-nmt-team/sb-nmt
def main():
    """Entry point: load a trained Seq2Seq model and translate sentences.

    Reads whitespace-tokenized sentences from stdin (one per line) and prints
    their translations to stdout; diagnostics go to stderr.

    Command-line options (JSON files unless noted):
      --params           model hyper-parameters
      --training_params  trainer hyper-parameters
      --model_state      path to a torch state-dict checkpoint
      --dataset          path of the dataset directory (builds src/tgt langs)
    """
    sys.stderr.write(sys.executable + "\n")
    # NOTE(review): device index is hard-coded to GPU #6 — consider making
    # this configurable like launch.py's cuda_visible_devices param.
    os.environ['CUDA_VISIBLE_DEVICES'] = "6"
    if torch.cuda.is_available():
        sys.stderr.write("Running on device {}\n".format(
            torch.cuda.current_device()))

    parser = argparse.ArgumentParser()
    parser.add_argument('--params')
    parser.add_argument('--training_params')
    parser.add_argument('--model_state')
    parser.add_argument('--dataset')
    args = parser.parse_args()

    hps = s2s.Seq2Seq.get_default_hparams()
    if args.params:
        with open(args.params, encoding='utf-8') as fin:
            hps = hps.parse_dict(json.load(fin))

    training_params = Trainer.get_default_hparams()
    if args.training_params:
        with open(args.training_params, encoding='utf-8') as fin:
            training_params = training_params.parse_dict(json.load(fin))

    full_dataset, src, tgt = read_problem(args.dataset, n_sents=None)

    # Optionally build the translation-memory search engine the model expects.
    searchengine = None
    if hps.tm_init:
        if hps.tm_overfitted:
            searchengine = OverfittedSearchEngine()
        else:
            searchengine = SearchEngine()
        searchengine.load(hps.tm_bin_path)
        searchengine.set_dictionary(full_dataset)
        sys.stderr.write("Using searchengine: {}\n".format(
            searchengine.__class__))

    training_params = training_params.parse_dict({'batch_size': 32})

    # Fall back to CPU when use_cuda is requested but no GPU is available
    # (mirrors the guard in launch.py); otherwise model.cuda() would raise.
    if training_params.use_cuda and not torch.cuda.is_available():
        sys.stderr.write("GPU not found, overriding use_cuda to False\n")
        training_params.set('use_cuda', False)

    writer = WriterMock()
    model = s2s.Seq2Seq(src,
                        tgt,
                        hps,
                        training_params,
                        writer=writer,
                        searchengine=searchengine)
    if training_params.use_cuda:
        model = model.cuda()

    sys.stderr.write("Loading the model state\n")
    # map_location lets a checkpoint saved on GPU be restored on a CPU-only
    # machine (the bare torch.load would fail there).
    device = 'cuda' if training_params.use_cuda else 'cpu'
    state_dict = torch.load(args.model_state, map_location=device)
    model.load_state_dict(state_dict)

    model.eval()
    sys.stderr.write("Ready!\n")
    sents = [line.strip().split() for line in sys.stdin]

    # dtype=object: tokenized sentences are ragged (varying lengths), and
    # NumPy >= 1.24 raises ValueError on ragged input without it.
    sents = np.array(sents, dtype=object)
    for sent in run_translation(src, model, sents, training_params):
        print(sent)