Example #1
0
def test(model_file, test_file, device=-1):
    """Evaluate a saved model against the samples in a test file.

    The training context stored alongside ``model_file`` is reloaded and
    used to rebuild the loader, the model hyper-parameters and the batch
    size, so that evaluation runs with the settings saved in that context.

    Args:
        model_file: Path handed to ``utils.Saver.load_context`` and to
            ``chainer.serializers.load_npz`` to restore the weights.
        test_file: Path of the evaluation data; it is read through the
            loader stored in the context and also passed to ``Evaluator``.
        device: Device id. When it is ``>= 0`` that device is selected and
            the model is moved onto it with ``to_gpu``; otherwise the model
            is left where ``_build_parser`` created it.
    """
    ctx = utils.Saver.load_context(model_file)
    seed = ctx.seed
    if seed is not None:
        utils.set_random_seed(seed, device)

    loader = ctx.loader
    samples = loader.load(test_file, train=False, bucketing=True)

    # Start from every entry of the context, then let the dedicated model
    # configuration (when one was saved) override those entries.
    build_args = dict(ctx)
    model_config = ctx.model_config
    if model_config is not None:
        build_args.update(model_config)
    model = _build_parser(**build_args)
    chainer.serializers.load_npz(model_file, model)

    if device >= 0:
        chainer.cuda.get_device_from_id(device).use()
        model.to_gpu(device)

    def _make_bar(n):
        return tqdm(total=n)

    progress = training.listeners.ProgressBar(_make_bar)
    progress.init(len(samples))
    scorer = Evaluator(model, loader.rel_map, test_file, logging.getLogger())

    # NOTE(review): presumably switches Chainer out of training mode for
    # the forward passes below — confirm against utils.chainer_train_off.
    utils.chainer_train_off()

    batches = samples.batch(ctx.batch_size, colwise=True, shuffle=False)
    for columns in batches:
        # The last column holds the targets; everything before it is input.
        ts = columns[-1]
        xs = columns[:-1]
        ys = model.forward(*xs)
        scorer.on_batch_end({'train': False, 'xs': xs, 'ys': ys, 'ts': ts})
        progress.update(len(ts))
    scorer.on_epoch_validate_end({})