Exemplo n.º 1
0
def test_dmp_save_and_load_bug_83():
    """Check that a DMP rebuilt from saved files replays the same trajectory.

    Regression test named after bug 83. A DMP with two task dimensions is
    initialized with 6 inputs/outputs and given random weights scaled by 1000.
    It is then written to YAML model/config files and reloaded into a fresh
    ``DMPBehavior``. The three arrays returned by ``trajectory()`` must match
    those of the original (presumably positions, velocities and
    accelerations). The temporary files are deleted whether or not
    saving/loading succeeds.
    """
    rng = np.random.RandomState(0)
    model_path = "tmp_dmp_model2.yaml"
    config_path = "tmp_dmp_config2.yaml"

    original = DMPBehavior(execution_time=1.0, dt=0.001)
    original.init(6, 6)
    start, goal = np.zeros(2), np.ones(2)
    original.set_meta_parameters(["x0", "g"], [start, goal])
    # Only the shape of the current parameters matters; they are replaced by
    # large random weights.
    weight_shape = original.get_params().shape
    original.set_params(rng.randn(*weight_shape) * 1000.0)

    y_ref, yd_ref, ydd_ref = original.trajectory()

    try:
        original.save(model_path)
        original.save_config(config_path)

        reloaded = DMPBehavior(configuration_file=model_path)
        reloaded.init(6, 6)
        reloaded.load_config(config_path)
    finally:
        # Clean up in the same order the files were written.
        for path in (model_path, config_path):
            if os.path.exists(path):
                os.remove(path)

    y_new, yd_new, ydd_new = reloaded.trajectory()

    for expected, actual in ((y_ref, y_new), (yd_ref, yd_new),
                             (ydd_ref, ydd_new)):
        assert_array_almost_equal(expected, actual)
Exemplo n.º 2
0
def _check_dmp_rollout(beh, x0, g, n_expected_steps, n_features):
    """Roll out ``beh`` from ``x0`` and check start, goal, length and size.

    Builds a state vector ``xva`` of ``3 * n_task_dims`` entries (module-level
    ``n_task_dims``) whose first ``n_task_dims`` entries are seeded with
    ``x0``. It then calls ``beh.reset()`` and invokes ``eval_loop(beh, xva)``
    while ``beh.can_step()``. ``eval_loop`` is defined elsewhere in this
    module and is assumed to update ``xva`` in place.

    Parameters
    ----------
    beh : DMPBehavior
        Initialized behavior whose ``x0``/``g`` meta-parameters are set.
    x0 : array
        Expected start; the first ``n_task_dims`` entries of ``xva`` must
        equal it after the first step.
    g : array
        Expected goal; the first ``n_task_dims`` entries of ``xva`` must match
        it to 3 decimals after the last step.
    n_expected_steps : int
        Expected number of ``eval_loop`` calls.
    n_features : int
        ``beh.get_n_params()`` must equal ``n_task_dims * n_features``.
    """
    xva = np.zeros(3 * n_task_dims)
    xva[:n_task_dims] = x0

    beh.reset()
    t = 0
    while beh.can_step():
        eval_loop(beh, xva)
        if t == 0:
            assert_array_almost_equal(xva[:n_task_dims], x0)
        t += 1
    assert_array_almost_equal(xva[:n_task_dims], g, decimal=3)
    assert_equal(t, n_expected_steps)
    assert_equal(beh.get_n_params(), n_task_dims * n_features)


def test_dmp_save_and_load():
    """Check that a DMP saved to YAML and reloaded behaves like the original.

    The original and the reloaded behavior are rolled out with the same
    start/goal. Each must start at ``x0``, reach ``g`` to 3 decimals, take
    854 steps and expose ``n_task_dims * 10`` parameters. Temporary YAML
    files are removed even if saving/loading fails.
    """
    n_features = 10
    # execution_time / dt + 1 = 0.853 / 0.001 + 1 steps.
    n_expected_steps = 854

    beh_original = DMPBehavior(execution_time=0.853, dt=0.001,
                               n_features=n_features)
    beh_original.init(3 * n_task_dims, 3 * n_task_dims)

    x0 = np.ones(n_task_dims) * 1.29
    g = np.ones(n_task_dims) * 2.13
    beh_original.set_meta_parameters(["x0", "g"], [x0, g])

    _check_dmp_rollout(beh_original, x0, g, n_expected_steps, n_features)

    try:
        beh_original.save("tmp_dmp_model.yaml")
        beh_original.save_config("tmp_dmp_config.yaml")

        beh_loaded = DMPBehavior(configuration_file="tmp_dmp_model.yaml")
        beh_loaded.init(3 * n_task_dims, 3 * n_task_dims)
        beh_loaded.load_config("tmp_dmp_config.yaml")
    finally:
        if os.path.exists("tmp_dmp_model.yaml"):
            os.remove("tmp_dmp_model.yaml")
        if os.path.exists("tmp_dmp_config.yaml"):
            os.remove("tmp_dmp_config.yaml")

    _check_dmp_rollout(beh_loaded, x0, g, n_expected_steps, n_features)
Exemplo n.º 3
0
def test_dmp_save_and_load2():
    """Check that an imitated DMP reproduces its trajectory after save/load.

    Two trajectories are used: an artificial sine trajectory, and a recorded
    one loaded together with its time step ``dt`` from
    ``reload_test_trajectory_and_dt.pickle`` (relative to the current working
    directory). For each, a DMP is fitted with ``imitate``, saved to YAML
    model/config files and reloaded into a fresh ``DMPBehavior``. The
    trajectories generated before and after reloading are compared to
    4 decimals. Temporary YAML files are removed even if saving/loading fails.
    """
    import pickle
    # Pickle data is binary: opening it in text mode ("r") makes pickle.load
    # raise a TypeError on Python 3; binary mode works on Python 2 and 3.
    # NOTE(review): if the fixture was written by Python 2, Python 3 may also
    # need pickle.load(f, encoding="latin1") for numpy arrays -- confirm.
    with open("reload_test_trajectory_and_dt.pickle", "rb") as f:
        recorded_trajectory, dt = pickle.load(f)

    n_artificial_steps = 500
    artificial_trajectory = np.zeros((n_artificial_steps, 1))
    # np.linspace's signature is (start, stop, num). The former call
    # np.linspace(0, 500, 1) returned the single sample [0.], which made the
    # whole artificial trajectory identically zero. Presumably 500 samples on
    # [0, 1] were intended -- confirm.
    artificial_trajectory[:, 0] = np.sin(
        np.linspace(0, 1, n_artificial_steps))
    trajectories = [artificial_trajectory, recorded_trajectory]
    for trajectory in trajectories:
        n_steps = trajectory.shape[0]
        n_task_dims = trajectory.shape[1]
        # n_steps samples span n_steps - 1 intervals of length dt.
        execution_time = dt * (n_steps - 1)
        beh_original = DMPBehavior(execution_time=execution_time,
                                   dt=dt,
                                   n_features=20)
        beh_original.init(3 * n_task_dims, 3 * n_task_dims)

        x0 = trajectory[0, :]
        g = trajectory[-1, :]
        beh_original.set_meta_parameters(["x0", "g"], [x0, g])

        # A single demonstration, arranged as (n_task_dims, n_steps, 1).
        X = np.zeros((n_task_dims, n_steps, 1))
        X[:, :, 0] = np.swapaxes(trajectory, axis1=1, axis2=0)
        beh_original.imitate(X, alpha=0.01)

        beh_original.reset()
        imitated_trajectory = beh_original.trajectory()
        try:
            beh_original.save("tmp_dmp_model.yaml")
            beh_original.save_config("tmp_dmp_config.yaml")

            beh_loaded = DMPBehavior(configuration_file="tmp_dmp_model.yaml")
            beh_loaded.init(3 * n_task_dims, 3 * n_task_dims)
            beh_loaded.load_config("tmp_dmp_config.yaml")
        finally:
            if os.path.exists("tmp_dmp_model.yaml"):
                os.remove("tmp_dmp_model.yaml")
            if os.path.exists("tmp_dmp_config.yaml"):
                os.remove("tmp_dmp_config.yaml")

        beh_loaded.reset()
        reimitated_trajectory = beh_loaded.trajectory()
        assert_array_almost_equal(imitated_trajectory,
                                  reimitated_trajectory,
                                  decimal=4)