# Example #1
# 0
def test(func, save_fig_path):
    """Evaluate the best saved model, plot its fitted orbits, and score it.

    Loads ``best_model.pth`` from ``save_fig_path`` into ``func`` and
    integrates it with ``odesolve`` from the module-level
    ``initial_condition`` using ``options``. It then draws one 3D panel per
    body into the module-level ``fig``: ground truth ``r1_sol`` / ``r2_sol``
    / ``r3_sol`` against the fitted trajectory. The figure is saved as
    ``best_model.png`` in ``save_fig_path``, and animation frames are written
    via ``generate_animation_figures``.

    Relies on module-level names defined elsewhere: ``options``,
    ``initial_condition``, ``fig``, ``r1_sol``, ``r2_sol``, ``r3_sol``,
    ``set_legend``, ``trajectory``, ``loss_weight_decay``, ``time_span``,
    ``t_list`` and ``animation_fig_path``.

    Args:
        func: Model whose state dict is loaded from ``best_model.pth``.
        save_fig_path: Directory containing ``best_model.pth``; the plot is
            written there as ``best_model.png``.

    Returns:
        A pair of Python floats, both divided by ``time_span.shape[0]``:
        the summed squared position error, and the
        ``loss_weight_decay``-weighted absolute error.
    """
    func.load_state_dict(torch.load(save_fig_path + '/best_model.pth'))

    out = odesolve(func, initial_condition,
                   options=options)  # , time_points=t_list)
    # Swap the first two axes so that out_all[0, :, k] is the k-th state
    # component over time (that is how it is indexed below).
    out_all = out.permute(1, 0, -1)

    out_all_data = out_all.detach().cpu().numpy()
    # The first 9 state components are the x/y/z positions of the three
    # bodies (columns 0:3, 3:6, 6:9 are plotted against r1/r2/r3_sol).
    position = out_all[..., :9]

    # Clear and draw into ``fig`` explicitly: plt.clf()/plt.savefig() act on
    # the *current* figure, which need not be the one we add subplots to.
    fig.clf()
    panels = (
        (131, r1_sol, 0),  # star A
        (132, r2_sol, 3),  # star B
        (133, r3_sol, 6),  # star C
    )
    for subplot_spec, ground_truth, col in panels:
        ax = fig.add_subplot(subplot_spec, projection="3d")
        ax.plot(ground_truth[:, 0],
                ground_truth[:, 1],
                ground_truth[:, 2],
                '-',
                label='Ground truth',
                linewidth=2.0)
        ax.plot(out_all_data[0, :, col],
                out_all_data[0, :, col + 1],
                out_all_data[0, :, col + 2],
                '--',
                label='Fitting',
                linewidth=2.0)
        if set_legend:
            ax.set_xlabel("x-coordinate", fontsize=14)
            ax.set_ylabel("y-coordinate", fontsize=14)
            ax.set_zlabel("z-coordinate", fontsize=14)
            ax.legend()

    fig.savefig('%s/best_model.png' % (save_fig_path))
    plt.draw()
    plt.pause(0.001)

    dif = position - trajectory
    # Squared distance per time step (presumably shape (1, N) -- confirm).
    dif = torch.sum(dif**2, -1, keepdim=False)
    dif1 = torch.squeeze(dif) * loss_weight_decay
    loss = torch.sum(torch.abs(dif1)) / float(time_span.shape[0])

    ###################################################################
    #                  plot into animation                            #
    ###################################################################
    generate_animation_figures(out_all_data[0, :, :3], out_all_data[0, :, 3:6],
                               out_all_data[0, :, 6:9], r1_sol, r2_sol, r3_sol,
                               np.array(t_list), animation_fig_path)

    return torch.sum(dif).item() / float(time_span.shape[0]), loss.item()
# Example #2
# 0
# Training-loop excerpt that integrates with the solver object directly and
# compares its dense-output evaluation against the integrated result.
# NOTE(review): the loop body continues beyond this excerpt (no loss /
# backward step is visible here).

# Lowest loss seen so far; presumably updated later in the loop (not visible).
best_loss = np.inf
if TrainMode:
    for _epoch in range(num_epochs):
        # Log the true masses next to the model's current estimates.
        print('M1 {}, estimated m1 {}'.format(m1, func.m1.item()))
        print('M2 {}, estimated m2 {}'.format(m2, func.m2.item()))
        print('M3 {}, estimated m3 {}'.format(m3, func.m3.item()))

        # Decay the learning rate at the top of every epoch -- including the
        # first one, so the initial ``lr`` value itself is never used.
        lr *= lr_decay
        adjust_learning_rate(optimizer, lr)
        optimizer.zero_grad()

        # NOTE(review): eval mode inside the training loop; confirm intended
        # (only matters if func has train/eval-dependent behavior).
        func.eval()

        # Build the solver object (return_solver=True) instead of a solution,
        # then integrate from options['t0'] at the options['t_eval'] times.
        # NOTE(review): y0 is passed to both odesolve and integrate.
        solver = odesolve(func,
                          y0=initial_condition,
                          options=options,
                          return_solver=True)  #, time_points=t_list)
        out = solver.integrate(y0=initial_condition,
                               t0=options['t0'],
                               t_eval=options['t_eval'])
        # Re-evaluate the same times via the solver's dense-output mode and
        # print the squared difference from ``out`` -- presumably a
        # consistency check of the dense interpolation.
        out_tmp = solver.evaluate_dense_mode(t_eval=options['t_eval'])
        print(torch.sum((out_tmp - out)**2))

        # Swap the first two axes of the dense-output result.
        out_all2 = out_tmp.permute(1, 0, -1)
        # The string literal below is disabled alternative code (explicit
        # Dopri5 solver); it is evaluated and discarded, never executed.
        '''
        solver = Dopri5(func, t0=options['t0'],t1=options['t1'],h=options['h'],
                      rtol =options['rtol'],atol=options['atol'],neval_max=options['neval_max'],
                      safety=options['safety'], keep_small_step=options['keep_small_step'])
        out2 = solver.integrate(initial_condition, t_eval = t_list)
        out_all2 = out2.permute(1,0,-1)
        '''
# Example #3
# 0
# Training-loop excerpt: integrate the model with odesolve and compare the
# predicted positions with ``trajectory``.
# NOTE(review): the loop body continues beyond this excerpt.

# Presumably appended to later in the loop (not visible in this excerpt).
loss_history = []

# Lowest loss seen so far; presumably updated later in the loop.
best_loss = np.inf
if TrainMode:
    for _epoch in range(num_epochs):
        # Log the true masses next to the model's current estimates.
        print('M1 {}, estimated m1 {}'.format(m1, func.m1.item()))
        print('M2 {}, estimated m2 {}'.format(m2, func.m2.item()))
        print('M3 {}, estimated m3 {}'.format(m3, func.m3.item()))

        # Decay the learning rate at the top of every epoch -- including the
        # first one, so the initial ``lr`` value itself is never used.
        lr *= lr_decay
        adjust_learning_rate(optimizer, lr)
        optimizer.zero_grad()

        # NOTE(review): eval mode inside the training loop; confirm intended.
        func.eval()

        out = odesolve(func, initial_condition,
                       options=options)  #, time_points=t_list)
        # Swap the first two axes of the solver output.
        out_all2 = out.permute(1, 0, -1)
        # The string literal below is disabled alternative code (explicit
        # Dopri5 solver); it is evaluated and discarded, never executed.
        '''
        solver = Dopri5(func, t0=options['t0'],t1=options['t1'],h=options['h'],
                      rtol =options['rtol'],atol=options['atol'],neval_max=options['neval_max'],
                      safety=options['safety'], keep_small_step=options['keep_small_step'])
        out2 = solver.integrate(initial_condition, t_eval = t_list)
        out_all2 = out2.permute(1,0,-1)
        '''

        # Detached NumPy copy (for plotting/logging) and the first 9 state
        # components, which are compared against ``trajectory``.
        out_all_data = out_all2.data.cpu().numpy()
        position = out_all2[..., :9]
        # loss = torch.sum((position - trajectory) ** 2)

        dif = position - trajectory
        # Sum of squares over the last axis; original note says the result is
        # 1 x N (unverified here).
        dif = torch.sum(dif**2, -1, keepdim=False)  # 1 x N
# Example #4
# 0
    def forward(self, t):
        """Return the stored ``w`` attribute.

        ``t`` is accepted so the object can be called with a time value, but
        it does not influence the result.
        """
        stored = self.w
        return stored

# np.linspace replaces scipy's NumPy alias ``sci.linspace``, which is
# deprecated and removed in recent SciPy releases; values are identical.
import numpy as np

func = Func()
history = History()

# Time grid: N_POINTS evenly spaced samples on [0, T_END]. T_END is reused as
# the integration end time so the two cannot drift apart.
T_END = 10.0
N_POINTS = 1000
time_span = np.linspace(0, T_END, N_POINTS)
t_list = time_span.tolist()

# configure training options
options = {
    'method': 'sym12async',
    't0': 0.0,
    't1': T_END,
    'h': None,
    'rtol': 1e-3,
    'atol': 1e-4,
    'print_neval': False,
    'neval_max': 1000000,
    'safety': None,
    't_eval': t_list,
    'dense_output': True,
    'interpolation_method': 'cubic',
}

# Integrate from the history value at t=0 and plot the first two state
# components against each other.
out = odesolve(func, history(0.0), options=options)
out = out.detach().cpu().numpy()
plt.plot(out[:, 0], out[:, 1])
plt.show()