Esempio n. 1
0
def test_process(mock_data, data_format: str):
    """Check that ``StepSequence.process_data`` keeps the rollout length intact.

    Two processing paths are exercised, selected by ``data_format``:

    * ``"numpy"`` -- every field except ``time`` is passed through
      ``scipy.signal.filtfilt`` using a 5th-order Butterworth filter with
      arbitrary parameters.
    * anything else -- the rollout is converted in place via ``torch()`` and only
      the ``time`` field is doubled, restricted to ``to.Tensor`` arguments.

    :param mock_data: fixture yielding exactly six items; the last one
                      (policy infos) is not used by this test
    :param data_format: ``"numpy"`` selects the filtering path, any other value
                        the PyTorch path
    """
    rewards, states, observations, actions, hidden, _ = mock_data

    # Assemble the rollout from the mocked recordings
    rollout = StepSequence(
        rewards=rewards,
        observations=observations,
        states=states,
        actions=actions,
        hidden=hidden,
    )

    if data_format != "numpy":
        # PyTorch path: convert the rollout in place, then scale only the time stamps
        rollout.torch()
        processed = StepSequence.process_data(
            rollout,
            lambda x: x * 2,
            fcn_arg_name="x",
            include_fields=["time"],
            fcn_arg_types=to.Tensor,
        )
    else:
        # NumPy path: filter coefficients are arbitrary, only the pipeline matters here
        num, den = signal.butter(N=5, Wn=10, fs=100)
        # Filter all signals along the step axis but leave the time field untouched;
        # a small padlen is passed explicitly to filtfilt
        processed = StepSequence.process_data(
            rollout,
            signal.filtfilt,
            fcn_arg_name="x",
            exclude_fields=["time"],
            b=num,
            a=den,
            padlen=2,
            axis=0,
        )

    # Processing must yield a StepSequence of the same number of steps
    assert isinstance(processed, StepSequence)
    assert processed.length == rollout.length
Esempio n. 2
0
    rollouts, file_names = load_rollouts_from_dir(args.dir)

    # Create a lowpass Butterworth filter with a cutoff at 50 Hz, and a sampling frequency of the orig system of 500Hz
    b, a = signal.butter(N=10, Wn=args.f_cut, fs=1 / args.dt)

    for ro, fname in zip(rollouts, file_names):
        ro.numpy()
        if args.verbose:
            plot_observations(ro)
            plt.gcf().canvas.set_window_title("Before")

        # Filter the signals, but not the time
        ro_proc = StepSequence.process_data(ro,
                                            signal.filtfilt,
                                            fcn_arg_name="x",
                                            exclude_fields=["time"],
                                            b=b,
                                            a=a,
                                            padlen=150,
                                            axis=0)

        # Downsample all data fields
        ro_proc = StepSequence.process_data(ro_proc,
                                            lambda x: x[::args.factor],
                                            fcn_arg_name="x")

        if args.verbose:
            plot_observations(ro_proc)
            plt.gcf().canvas.set_window_title("After")
            plt.show()

        # Save in a new folder on the same level as the current folder