コード例 #1
0
ファイル: test_only.py プロジェクト: alvin-leong/FocusSeq2Seq
    print('Current user:'******'Dataset directory:', data_dir)

    print('PyTorch Version:', torch.__version__)
    print('# CPUs:', multiprocessing.cpu_count())
    print('# GPUs:', torch.cuda.device_count())
    print('Current cuda device:', torch.cuda.current_device())
    print('Device name:', torch.cuda.get_device_name(0))

    print('Seed:', config.seed)
    torch.manual_seed(config.seed)
    torch.cuda.manual_seed(config.seed)
    np.random.seed(config.seed)
    random.seed(config.seed)

    train_loader, val_loader, test_loader, word2id, id2word = get_loader(
        config, data_dir)

    PAD_ID, UNK_ID = word2id['<pad>'], word2id['<unk>']
    vocab_size = len(word2id)
    print('Loaded data loaders & Vocab!')
    print('Vocab Size:', vocab_size)

    model = build_model(config, word2id, id2word)

    print('#===== Parameters =====#')
    for name, p in model.named_parameters():
        print(name, '\t', list(p.size()))

    print('#==== Weight Initialization ====#')
    if config.model == 'NQG':
        for name, p in model.named_parameters():
コード例 #2
0
ファイル: evaluate.py プロジェクト: zheng5yu9/FocusSeq2Seq
    return metric_result, hypotheses, best_hypothesis, hyp_focus, hyp_attention


if __name__ == '__main__':
    # Script entry point: parse the config, rebuild the test data loader and
    # the model, and restore the model weights from a saved checkpoint before
    # the evaluation step that follows.
    # NOTE(review): `torch` and `device` are not defined in this block; they
    # are presumably module-level names defined earlier in this file.
    from pathlib import Path
    current_dir = Path(__file__).resolve().parent

    import configs
    config = configs.get_config()
    print(config)

    from build_utils import get_loader, build_model, get_ckpt_name

    # Build Data Loader
    # The dataset directory is '<config.data>_out', located next to this script.
    data_dir = current_dir.joinpath(config.data + '_out')
    # Only the third loader (test) and the vocabulary mappings are used here.
    _, _, test_loader, word2id, id2word = get_loader(config, data_dir)

    # Build Model
    model = build_model(config, word2id, id2word)
    model.to(device)

    # Load Model from checkpoint
    # NOTE(review): unlike data_dir, the checkpoint directory is resolved
    # against the current working directory, not this script's directory.
    # Presumably it must match the working directory the checkpoints were
    # written from; confirm against the training script.
    ckpt_dir = Path(f"./ckpt/{config.model}/").resolve()
    filename = get_ckpt_name(config)
    filename += f"_epoch{config.load_ckpt}.pkl"
    ckpt_path = ckpt_dir.joinpath(filename)
    # map_location remaps stored tensors onto the evaluation device. Without
    # it, a checkpoint saved from a CUDA model cannot be loaded on a CPU-only
    # machine.
    # torch.load unpickles the file: only load checkpoints from trusted sources.
    ckpt = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(ckpt['model'])
    print('Loaded model from', ckpt_path)

    # Run Evaluation