n_fit = int(t_fit // Ts)
    input_data = u[0:n_fit]
    state_data = x_noise[0:n_fit]
    u_torch = torch.from_numpy(input_data)
    x_true_torch = torch.from_numpy(state_data)

    # Setup neural model structure
    ss_model = NeuralStateSpaceModel(n_x=2, n_u=1, n_feat=64, init_small=False)
    nn_solution = NeuralStateSpaceSimulator(ss_model)

    # Setup optimizer
    optimizer = optim.Adam(nn_solution.ss_model.parameters(), lr=lr)

    # Scale loss with respect to the initial one
    with torch.no_grad():
        x_est_torch = nn_solution.f_onestep(x_true_torch, u_torch)
        err_init = x_est_torch - x_true_torch
        scale_error = torch.sqrt(torch.mean(err_init**2, dim=0))

    # Training loop
    LOSS = []
    start_time = time.time()
    for itr in range(0, num_iter):

        optimizer.zero_grad()
        x_est_torch = nn_solution.f_onestep(x_true_torch, u_torch)
        err = x_est_torch - x_true_torch
        err_scaled = err / scale_error
        loss_sc = torch.mean(
            (err_scaled)**
            2)  #torch.mean(torch.sq(batch_x[:,1:,:] - batch_x_pred[:,1:,:]))
# Example #2
    # Fit data to pytorch tensors #
    n_fit = int(len_fit//Ts)
    u_fit = u[0:n_fit]
    x_fit = x_noise[0:n_fit]
    t_fit = t[0:n_fit]
    u_fit_torch = torch.from_numpy(u_fit)
    x_meas_fit_torch = torch.from_numpy(x_fit)

    # Setup optimizer
    params = list(nn_solution.ss_model.parameters())
    optimizer = optim.Adam(params, lr=lr)
    end = time.time()

    # Scale loss with respect to the initial one
    with torch.no_grad():
        x_est_torch = nn_solution.f_onestep(x_meas_fit_torch, u_fit_torch)
        err_init = x_est_torch - x_meas_fit_torch
        scale_error = torch.sqrt(torch.mean((err_init)**2, dim=0))


    LOSS = []
    start_time = time.time()
    # Training loop
    for itr in range(1, num_iter + 1):
        optimizer.zero_grad()

        # Perform one-step ahead prediction
        x_pred_torch = nn_solution.f_onestep(x_meas_fit_torch, u_fit_torch)

        # Compute fit loss
        err = x_pred_torch - x_meas_fit_torch