Ejemplo n.º 1
0
def test_load_and_infer_improvement(nntxt_idx, parameter_format, dataset_sample_num):
    '''Check cases that only the refactored NNP loader can handle.

    For the nntxt case selected by ``nntxt_idx``, loading and inferring with
    the legacy implementation (``ref_load`` + ``_ref_forward``) must raise
    ``ValueError``, while the refactored implementation (``load.load`` +
    ``_forward``) must complete without raising.
    '''
    with generate_case_from_nntxt_str(NNTXT_IMPROVEMENT_CASES[nntxt_idx],
                                      parameter_format,
                                      dataset_sample_num) as nnp_file:
        # Either the legacy load or the legacy forward is expected to fail.
        with pytest.raises(ValueError) as excinfo:
            ref_info = ref_load(nnp_file)
            partial(common_forward, forward_func=_ref_forward)(ref_info)
        # excinfo is only filled in after the raises-block exits, and code
        # placed after the raising call inside it would never run, so report
        # the captured legacy failure here.
        print(excinfo)

        # The refactored path must succeed on the same file; any exception
        # raised here fails the test.
        info = load.load(nnp_file)
        partial(common_forward, forward_func=_forward)(info)
Ejemplo n.º 2
0
def test_load_and_infer_equivalence(nntxt_idx, parameter_format, dataset_sample_num):
    '''Compare inference output of the legacy and refactored NNP loaders.

    The refactoring covers both the network part and the load function.
    An .nnp file is generated from the nntxt case at ``nntxt_idx``. Its
    parameters are stored in ``parameter_format``, and its dataset uri is
    replaced by a temporarily generated random dataset sized by
    ``dataset_sample_num``.

    The file is then loaded and run forward (similar to cli/forward.py) twice:
    once via the legacy path (``ref_load`` / ``_ref_forward``) and once via
    the refactored path (``load.load`` / ``_forward``). The two results must
    be equal.
    '''
    case_nntxt = NNTXT_EQUIVALENCE_CASES[nntxt_idx]
    with generate_case_from_nntxt_str(case_nntxt, parameter_format,
                                      dataset_sample_num) as nnp_path:
        legacy_info = ref_load(nnp_path)
        expected = common_forward(legacy_info, forward_func=_ref_forward)

        refactored_info = load.load(nnp_path)
        actual = common_forward(refactored_info, forward_func=_forward)

    # Compared after leaving the generate_case_from_nntxt_str context.
    assert_tensor_equal(actual, expected)
Ejemplo n.º 3
0
def test_load_and_train_equivalence(nntxt_idx, parameter_format,
                                    dataset_sample_num, batch_size):
    '''Check that training with the refactored NNP loader matches legacy.

    An .nnp file is generated from the nntxt case at ``nntxt_idx``. It is
    loaded once with the legacy loader (``ref_load``) and once with the
    refactored one (``load.load``). Each is then trained for ``m_iter``
    iterations, similar to cli/train.py.

    Both runs must yield the same number of (cost, error) pairs, and these
    must agree pairwise within rtol=1e-2 / atol=1e-3.
    '''
    # for debugging
    save_v = False
    output_network_topology = False
    verbose = False
    m_iter = 10

    class Callback:
        pass

    def _make_config(impl, forward, backward):
        # Build a TrainConfig; the two runs differ only in callbacks and impl.
        config = TrainConfig()
        config.on_iter = None
        config.save_optimizer_variable = False
        config.save_evaluation_variable = False
        config.start_iteration = 0
        config.end_iteration = m_iter
        config.enable_save_variable = save_v
        cb = Callback()
        cb.forward = forward
        cb.backward = backward
        config.cb = cb
        config.impl = impl
        return config

    legacy_config = _make_config(
        "legacy",
        lambda o: o.network.forward(o.forward_sequence),
        lambda o, b: o.network.backward(o.backward_sequence, b))
    new_config = _make_config(
        "new",
        lambda x: x.target.forward(clear_no_need_grad=True),
        lambda x, b: x.target.backward(clear_buffer=True))

    with generate_case_from_nntxt_str(NNTXT_EQUIVALENCE_CASES[nntxt_idx],
                                      parameter_format, dataset_sample_num,
                                      batch_size) as nnp_file:
        # Presumably resets the global parameter registry so each run starts
        # only from the parameters stored in the .nnp file — confirm.
        nn.clear_parameters()
        info = ref_load(nnp_file, batch_size=batch_size)
        ref_result = [(cost, error) for cost, error
                      in partial(train, config=legacy_config)(info)]

        nn.clear_parameters()
        info = load.load(nnp_file, batch_size=batch_size)

        if output_network_topology:
            for n, opt in info.optimizers.items():
                print(n)
                opt.network.execute_on_proto(Verifier())

        result = [(cost, error) for cost, error
                  in partial(train, config=new_config)(info)]

        # zip() below would silently truncate to the shorter run, letting a
        # run that yields fewer (or no) iterations pass vacuously.
        assert len(ref_result) == len(result), (
            "Error: {}: {} vs {} iterations".format(
                nntxt_idx, len(ref_result), len(result)))

        for i, ((cost_ref, error_ref),
                (cost, error)) in enumerate(zip(ref_result, result)):
            if verbose:
                print("{}: cost: {} <--> {}".format(i, cost_ref, cost))
                print("{}: error: {} <--> {}".format(i, error_ref, error))
            assert_allclose(np.array([cost_ref, error_ref]),
                            np.array([cost, error]),
                            rtol=1e-2,
                            atol=1e-3,
                            err_msg="Error: {}".format(nntxt_idx))