Example #1
0
def main(args):
    """Evaluate a pair of KMPNet models (direction 0 and 1) on test tasks.

    The function seeds the RNGs, picks the environment-specific helpers for
    ``args.env_type``, builds two ``KMPNet`` instances, restores them from
    ``args.model_path`` when ``args.start_epoch > 0`` and finally runs
    ``eval_tasks_mpnet`` on the seen and/or unseen test sets.

    Args:
        args: parsed command-line namespace. Attributes read here:
            ``model_path``, ``device``, ``env_type``, ``total_input_size``,
            ``AE_input_size``, ``mlp_input_size``, ``output_size``,
            ``start_epoch``, ``opt``, ``learning_rate``, ``data_folder``,
            ``world_size``, ``seen_N``/``seen_NP``/``seen_s``/``seen_sp`` and
            ``unseen_N``/``unseen_NP``/``unseen_s``/``unseen_sp``.

    Raises:
        ValueError: if ``args.env_type`` is not one of the environments
            handled below.
    """
    print(args.model_path)

    # Fresh random seeds; they are replaced by the checkpointed seeds further
    # down when resuming from a saved epoch.
    torch_seed = np.random.randint(low=0, high=1000)
    np_seed = np.random.randint(low=0, high=1000)
    py_seed = np.random.randint(low=0, high=1000)
    torch.manual_seed(torch_seed)
    np.random.seed(np_seed)
    random.seed(py_seed)
    if torch.cuda.is_available():
        torch.cuda.set_device(args.device)

    # ------------------------------------------------------------------
    # Environment-specific helpers.
    # ------------------------------------------------------------------
    env_type = args.env_type
    acrobot_variants = ('acrobot_obs', 'acrobot_obs_2', 'acrobot_obs_8')

    # `system` is read unconditionally by load_train_dataset() below, but the
    # original only assigned it for two of the acrobot variants, so every
    # other environment died with UnboundLocalError.
    # NOTE(review): pendulum/cartpole_obs now pass None -- confirm that
    # data_loader.load_train_dataset accepts a missing system.
    system = None
    if env_type == 'pendulum':
        env = pendulum
        dynamics = pendulum.dynamics
        cae = cae_identity
        mlp = MLP
        obs_f = False
    elif env_type == 'cartpole_obs':
        env = cartpole
        dynamics = cartpole.dynamics
        cae = CAE_acrobot_voxel_2d
        mlp = mlp_acrobot.MLP
        obs_f = True
    elif env_type in acrobot_variants:
        env = acrobot_obs
        # All acrobot variants propagate through the C++ system; previously
        # 'acrobot_obs_2' forgot to create it although its dynamics lambda
        # referenced it.
        system = _sst_module.PSOPTAcrobot()
        cpp_propagator = _sst_module.SystemPropagator()
        dynamics = lambda x, u, t: cpp_propagator.propagate(system, x, u, t)
        # The variants differ only in the encoder / MLP architecture.
        if env_type == 'acrobot_obs':
            cae = CAE_acrobot_voxel_2d
            mlp = mlp_acrobot.MLP
        elif env_type == 'acrobot_obs_2':
            cae = CAE_acrobot_voxel_2d_2
            mlp = mlp_acrobot.MLP2
        else:  # 'acrobot_obs_8'
            cae = CAE_acrobot_voxel_2d_3
            mlp = mlp_acrobot.MLP6
        obs_f = True
    else:
        raise ValueError('unsupported env_type: %r' % (env_type,))

    IsInCollision = env.IsInCollision
    normalize = env.normalize
    unnormalize = env.unnormalize
    jax_dynamics = env.jax_dynamics
    enforce_bounds = env.enforce_bounds

    # Jacobians of the dynamics w.r.t. its first and second argument.
    jac_A = jax.jacfwd(jax_dynamics, argnums=0)
    jac_B = jax.jacfwd(jax_dynamics, argnums=1)

    # ------------------------------------------------------------------
    # Networks: build both first, then restore/configure each in turn.
    # ------------------------------------------------------------------
    mpNet0 = KMPNet(args.total_input_size, args.AE_input_size,
                    args.mlp_input_size, args.output_size, cae, mlp)
    mpNet1 = KMPNet(args.total_input_size, args.AE_input_size,
                    args.mlp_input_size, args.output_size, cae, mlp)

    optimizers = {
        'Adagrad': (torch.optim.Adagrad, {}),
        'Adam': (torch.optim.Adam, {}),
        'SGD': (torch.optim.SGD, {'momentum': 0.9}),
    }
    resume = args.start_epoch > 0

    def restore_and_configure(net, direction):
        """Load weights/seeds, move to GPU, attach and restore the optimizer."""
        checkpoint = os.path.join(
            args.model_path,
            'kmpnet_epoch_%d_direction_%d.pkl' % (args.start_epoch, direction))
        if resume:
            load_net_state(net, checkpoint)
            saved_torch_seed, saved_np_seed, saved_py_seed = load_seed(
                checkpoint)
            # set seed after loading
            torch.manual_seed(saved_torch_seed)
            np.random.seed(saved_np_seed)
            random.seed(saved_py_seed)
        if torch.cuda.is_available():
            net.cuda()
            net.mlp.cuda()
            net.encoder.cuda()
        # The optimizer is created regardless of CUDA availability (after the
        # optional .cuda() move): load_opt_state() below needs one on CPU-only
        # machines too, but it used to be created only inside the CUDA branch.
        opt_spec = optimizers.get(args.opt)
        if opt_spec is not None:
            opt_cls, opt_kwargs = opt_spec
            net.set_opt(opt_cls, lr=args.learning_rate, **opt_kwargs)
        if resume:
            load_opt_state(net, checkpoint)

    for direction, net in enumerate((mpNet0, mpNet1)):
        restore_and_configure(net, direction)

    # NOTE(review): the waypoints returned here are not used by the
    # evaluation below; the call is kept as in the original script.
    _, waypoint_dataset, waypoint_targets, _, _, _, _, _ = data_loader.load_train_dataset(
        1, 2, args.data_folder, obs_f, 1, dynamics, enforce_bounds, system,
        0.02, 20)

    # ------------------------------------------------------------------
    # Test data (each set is only loaded when its task count is positive).
    # ------------------------------------------------------------------
    print('loading...')
    test_sets = []
    if args.seen_N > 0:
        test_sets.append(
            ('seen',
             data_loader.load_test_dataset(args.seen_N, args.seen_NP,
                                           args.data_folder, obs_f,
                                           args.seen_s, args.seen_sp)))
    if args.unseen_N > 0:
        test_sets.append(
            ('unseen',
             data_loader.load_test_dataset(args.unseen_N, args.unseen_NP,
                                           args.data_folder, obs_f,
                                           args.unseen_s, args.unseen_sp)))

    # ------------------------------------------------------------------
    # Evaluation.
    # ------------------------------------------------------------------
    print('testing...')
    # Bind the world size so the evaluator can call these with one argument.
    normalize_func = lambda x: normalize(x, args.world_size)
    unnormalize_func = lambda x: unnormalize(x, args.world_size)
    for label, test_data in test_sets:
        _fes_path, _valid_path = eval_tasks_mpnet(
            mpNet0, mpNet1, args.env_type, test_data, args.model_path, label,
            normalize_func, unnormalize_func, dynamics, jac_A, jac_B,
            enforce_bounds, IsInCollision)
Example #2
0
def main(args):
    """Plot stored cartpole test trajectories next to a KMPNet rollout.

    For up to two environments and up to ten paths per environment of the
    "seen" test set, this opens a figure with a position panel and a velocity
    panel, draws the collision map, re-propagates the stored controls through
    the C++ dynamics and plots the resulting states, waits for a key/button
    press, then rolls the network forward from the start state and plots the
    predicted states.

    Args:
        args: parsed command-line namespace. Attributes read here:
            ``env_type``, ``total_input_size``, ``AE_input_size``,
            ``mlp_input_size``, ``output_size``, ``model_dir``, ``loss``,
            ``multigoal``, ``learning_rate``, ``opt``, ``num_steps``,
            ``start_epoch``, ``direction``, ``data_folder``, ``world_size``,
            ``seen_N``/``seen_NP``/``seen_s``/``seen_sp`` and
            ``unseen_N``/``unseen_NP``/``unseen_s``/``unseen_sp``.

    Raises:
        ValueError: if ``args.env_type`` is not one of the cartpole variants
            (the only ones for which a model is configured here), or if
            ``args.seen_N`` is not positive (the plots need the seen set).
    """
    # The first draw is unused (torch is not seeded here); it is kept so the
    # NumPy/Python seeds come from the same positions of the RNG stream.
    _unused_torch_seed = np.random.randint(low=0, high=1000)
    np_seed = np.random.randint(low=0, high=1000)
    py_seed = np.random.randint(low=0, high=1000)
    np.random.seed(np_seed)
    random.seed(py_seed)

    # ------------------------------------------------------------------
    # Environment setup.  The cartpole variants differ only in the MLP class.
    # NOTE(review): '_3' -> MLP4 and '_4' -> MLP3 is reproduced as found;
    # confirm the crossed numbering is intentional.
    # ------------------------------------------------------------------
    cartpole_mlps = {
        'cartpole_obs': 'MLP',
        'cartpole_obs_2': 'MLP2',
        'cartpole_obs_3': 'MLP4',
        'cartpole_obs_4': 'MLP3',
    }
    if args.env_type not in cartpole_mlps:
        # 'pendulum' and 'acrobot_obs' used to be accepted here but never
        # defined an encoder/MLP, so they crashed with UnboundLocalError when
        # the network was built.
        raise ValueError(
            'env_type %r has no model configuration in this script; '
            'supported: %s' % (args.env_type, sorted(cartpole_mlps)))

    obs_f = True
    obs_width = 4.0
    step_sz = 0.002
    system = _sst_module.PSOPTCartPole()
    cpp_propagator = _sst_module.SystemPropagator()
    normalize = cart_pole_obs.normalize
    unnormalize = cart_pole_obs.unnormalize
    enforce_bounds = cart_pole_obs.enforce_bounds
    mlp = getattr(mlp_cartpole, cartpole_mlps[args.env_type])
    cae = CAE_cartpole_voxel_2d

    def dynamics(x, u, t):
        """Propagate state x under control u for duration t."""
        return cpp_propagator.propagate(system, x, u, t)

    mpnet = KMPNet(args.total_input_size, args.AE_input_size,
                   args.mlp_input_size, args.output_size, cae, mlp, None)

    # ------------------------------------------------------------------
    # Checkpoint directory: <model_dir><env>_lr<lr>_<opt>[_loss_<loss>]
    # _step_<n>[_multigoal]/
    # ------------------------------------------------------------------
    run_name = '%s_lr%f_%s' % (args.env_type, args.learning_rate, args.opt)
    if args.loss != 'mse':
        run_name += '_loss_%s' % (args.loss,)
    run_name += '_step_%d' % (args.num_steps,)
    if args.multigoal != 0:
        run_name += '_multigoal'
    model_dir = args.model_dir + run_name + '/'
    print(model_dir)
    os.makedirs(model_dir, exist_ok=True)

    checkpoint = os.path.join(
        model_dir, 'kmpnet_epoch_%d_direction_%d_step_%d.pkl' %
        (args.start_epoch, args.direction, args.num_steps))
    resume = args.start_epoch > 0
    if resume:
        load_net_state(mpnet, checkpoint)
        torch_seed, np_seed, py_seed = load_seed(checkpoint)
        # set seed after loading
        torch.manual_seed(torch_seed)
        np.random.seed(np_seed)
        random.seed(py_seed)

    if torch.cuda.is_available():
        mpnet.cuda()
        mpnet.mlp.cuda()
        mpnet.encoder.cuda()
    # The optimizer is attached regardless of CUDA availability (after the
    # optional .cuda() move): load_opt_state() below needs one on CPU-only
    # machines too, but it used to be created only inside the CUDA branch.
    optimizers = {
        'Adagrad': (torch.optim.Adagrad, {}),
        'Adam': (torch.optim.Adam, {}),
        'SGD': (torch.optim.SGD, {'momentum': 0.9}),
        'ASGD': (torch.optim.ASGD, {}),
    }
    opt_spec = optimizers.get(args.opt)
    if opt_spec is not None:
        opt_cls, opt_kwargs = opt_spec
        mpnet.set_opt(opt_cls, lr=args.learning_rate, **opt_kwargs)
    if resume:
        load_opt_state(mpnet, checkpoint)

    mpnet.eval()

    # ------------------------------------------------------------------
    # Data.
    # ------------------------------------------------------------------
    print('loading...')
    seen_test_data = None
    if args.seen_N > 0:
        seen_test_data = data_loader.load_test_dataset(args.seen_N,
                                                       args.seen_NP,
                                                       args.data_folder, obs_f,
                                                       args.seen_s,
                                                       args.seen_sp)
    if args.unseen_N > 0:
        # NOTE(review): loaded as in the original script but not used below.
        unseen_test_data = data_loader.load_test_dataset(
            args.unseen_N, args.unseen_NP, args.data_folder, obs_f,
            args.unseen_s, args.unseen_sp)
    if seen_test_data is None:
        # Previously this surfaced as an UnboundLocalError further down.
        raise ValueError('seen_N must be > 0: the plots use the seen test set')

    print('testing...')

    # ------------------------------------------------------------------
    # Plot helpers (defined once instead of once per plotted path).
    # ------------------------------------------------------------------
    def wrap_angle(x):
        """Return a copy of x with circular dimensions wrapped to (-pi, pi]."""
        circular = system.is_circular_topology()
        res = np.array(x)
        for i in range(len(x)):
            if circular[i]:
                res[i] = x[i] - np.floor(x[i] / (2 * np.pi)) * (2 * np.pi)
                if res[i] > np.pi:
                    res[i] = res[i] - 2 * np.pi
        return res

    def wrap_rows(states):
        """Stack states into an array and wrap the angles of every row."""
        rows = np.array(states)
        for idx in range(len(rows)):
            rows[idx] = wrap_angle(rows[idx])
        return rows

    def redraw(fig):
        """Push pending artists of fig to the screen."""
        fig.canvas.draw()
        fig.canvas.flush_events()

    def scatter_states(fig, ax_pos, ax_vel, rows, colour, **vel_kwargs):
        """Plot columns (0, 2) on ax_pos and columns (1, 3) on ax_vel."""
        ax_pos.scatter(rows[:, 0], rows[:, 2], c=colour)
        redraw(fig)
        ax_vel.scatter(rows[:, 1], rows[:, 3], c=colour, **vel_kwargs)
        redraw(fig)

    # Grid used to sample the collision map: columns 0 and 2 are scanned,
    # columns 1 and 3 are held at zero.
    dx = 1
    dtheta = 0.1
    n_x = int(2 * 30. / dx)
    n_theta = int(2 * np.pi / dtheta)
    max_envs = 2
    max_paths = 10
    rollout_steps = 50
    half_width = obs_width / 2

    obc, obs, paths, _sgs, _path_lengths, controls, costs = seen_test_data
    plt.ion()
    # Clamp to what was actually loaded so short test sets do not raise
    # IndexError.
    for envi in range(min(max_envs, len(paths))):
        # Everything in this section depends on the environment only, so it
        # is computed once per environment rather than once per path.
        # Each obstacle centre becomes the flattened list of its 4 corners.
        obs_i = [[
            centre[0] - half_width, centre[1] - half_width,
            centre[0] - half_width, centre[1] + half_width,
            centre[0] + half_width, centre[1] + half_width,
            centre[0] + half_width, centre[1] - half_width
        ] for centre in obs[envi]]

        feasible, infeasible = [], []
        for i in range(n_x):
            for j in range(n_theta):
                x = np.array([dx * i - 30, 0., dtheta * j - np.pi, 0.])
                if IsInCollision(x, obs_i):
                    infeasible.append(x)
                else:
                    feasible.append(x)
        # reshape keeps the arrays 2-D even when one of the lists is empty,
        # so the column slicing below cannot fail.
        feasible_points = np.array(feasible).reshape(-1, 4)
        infeasible_points = np.array(infeasible).reshape(-1, 4)
        print('feasible points')
        print(feasible_points)
        print('infeasible points')
        print(infeasible_points)

        if obc is None:
            bobs = None
        else:
            bobs = np.array([obc[envi]]).astype(np.float32)
            print(bobs.shape)
            bobs = to_var(torch.FloatTensor(bobs))

        for pathi in range(min(max_paths, len(paths[envi]))):
            fig = plt.figure()
            ax = fig.add_subplot(121)
            ax_vel = fig.add_subplot(122)
            ax.set_xlim(-30, 30)
            ax.set_ylim(-np.pi, np.pi)
            ax_vel.set_xlim(-40, 40)
            ax_vel.set_ylim(-2, 2)

            ax.scatter(feasible_points[:, 0],
                       feasible_points[:, 2],
                       c='yellow')
            ax.scatter(infeasible_points[:, 0],
                       infeasible_points[:, 2],
                       c='pink')
            redraw(fig)

            xs = paths[envi][pathi]
            us = controls[envi][pathi]
            ts = costs[envi][pathi]

            # Re-propagate the stored controls.  Each segment restarts from
            # the stored waypoint because propagation drifts slightly from
            # the recorded data.
            p_start = xs[0]
            state = [p_start]
            for k in range(len(us)):
                max_steps = int(ts[k] / step_sz)
                p_start = xs[k]
                state[-1] = xs[k]
                for _ in range(max_steps):
                    p_start = dynamics(p_start, us[k], step_sz)
                    p_start = enforce_bounds(p_start)
                    state.append(p_start)
            state[-1] = xs[-1]

            scatter_states(fig, ax, ax_vel, wrap_rows(state), 'green', s=0.1)
            plt.waitforbuttonpress()

            # Network rollout: feed (current state, goal) and take the
            # prediction as the next state.
            mpnet_paths = []
            state = xs[0]
            with torch.no_grad():  # inference only, no autograd graph needed
                for _ in range(rollout_steps):
                    mpnet_paths.append(state)
                    bi = np.array([np.concatenate([state, xs[-1]])])
                    bi = torch.from_numpy(bi).type(torch.FloatTensor)
                    print(bi)
                    bi = to_var(normalize(bi, args.world_size))
                    bt = mpnet(bi, bobs).cpu()
                    bt = unnormalize(bt, args.world_size)
                    bt = bt.detach().numpy()
                    print(bt.shape)
                    state = bt[0]

            print(mpnet_paths)
            xs_to_plot = wrap_rows(mpnet_paths)
            print(len(xs_to_plot))
            scatter_states(fig, ax, ax_vel, xs_to_plot, 'lightgreen')
            plt.waitforbuttonpress()
            # Release the figure; otherwise one window per path piles up and
            # matplotlib warns once more than 20 figures are open.
            plt.close(fig)
Example #3
0
def main(args):
    # set seed
    print(args.model_path)
    torch_seed = np.random.randint(low=0, high=1000)
    np_seed = np.random.randint(low=0, high=1000)
    py_seed = np.random.randint(low=0, high=1000)
    #torch.manual_seed(torch_seed)
    np.random.seed(np_seed)
    random.seed(py_seed)
    # Build the models
    #if torch.cuda.is_available():
    #    torch.cuda.set_device(args.device)

    # setup evaluation function and load function
    if args.env_type == 'pendulum':
        obs_file = None
        obc_file = None
        obs_f = False
        #system = standard_cpp_systems.PSOPTPendulum()
        #bvp_solver = _sst_module.PSOPTBVPWrapper(system, 2, 1, 0)
    elif args.env_type == 'cartpole_obs':
        step_sz = 0.002
        num_steps = 21
        goal_radius = 1.5
        random_seed = 0
        delta_near = 2.0
        delta_drain = 1.2
        cost_threshold = 1.2
        min_time_steps = 10
        max_time_steps = 200
        integration_step = 0.002
        obs_f = True
        obs_file = None
        obc_file = None
        system = _sst_module.PSOPTCartPole()
        cpp_propagator = _sst_module.SystemPropagator()
        dynamics = lambda x, u, t: cpp_propagator.propagate(system, x, u, t)
        obs_width = 4.0
        IsInCollision = cartpole_IsInCollision
        enforce_bounds = cartpole_enforce_bounds

    elif args.env_type == 'acrobot_obs':
        obs_file = None
        obc_file = None
        system = _sst_module.PSOPTAcrobot()
        cpp_propagator = _sst_module.SystemPropagator()
        dynamics = lambda x, u, t: cpp_propagator.propagate(system, x, u, t)

        obs_f = True
        bvp_solver = _sst_module.PSOPTBVPWrapper(system, 4, 1, 0)
        step_sz = 0.02
        num_steps = 21
        traj_opt = lambda x0, x1, step_sz, num_steps, x_init, u_init, t_init: bvp_solver.solve(
            x0, x1, 200, num_steps, step_sz * 1, step_sz *
            (num_steps - 1), x_init, u_init, t_init)
        goal_S0 = np.diag([1., 1., 0, 0])
        #goal_S0 = np.identity(4)
        goal_rho0 = 1.0
        IsInCollision = acrobot_IsInCollision
        enforce_bounds = acrobot_enforce_bounds

    if args.env_type == 'pendulum':
        step_sz = 0.002
        num_steps = 20

    elif args.env_type in [
            'acrobot_obs', 'acrobot_obs_2', 'acrobot_obs_3', 'acrobot_obs_4',
            'acrobot_obs_8'
    ]:
        #system = standard_cpp_systems.RectangleObs(obs[i], 6.0, 'acrobot')
        obs_width = 6.0
        step_sz = 0.02
        num_steps = 21
        goal_radius = 2.0
        random_seed = 0
        delta_near = 0.1
        delta_drain = 0.05

    # load previously trained model if start epoch > 0
    #model_path='kmpnet_epoch_%d_direction_0_step_%d.pkl' %(args.start_epoch, args.num_steps)
    mlp_path = os.path.join(os.getcwd() + '/c++/',
                            'acrobot_mlp_annotated_test_gpu.pt')
    encoder_path = os.path.join(os.getcwd() + '/c++/',
                                'acrobot_encoder_annotated_test_cpu.pt')
    print('mlp_path:')
    print(mlp_path)

    #####################################################
    def plan_one_path(obs_i, obs, obc, detailed_data_path, data_path,
                      start_state, goal_state, goal_inform_state, cost_i,
                      max_iteration, out_queue_t, out_queue_cost, random_seed):
        if args.env_type == 'pendulum':
            system = standard_cpp_systems.PSOPTPendulum()
            bvp_solver = _sst_module.PSOPTBVPWrapper(system, 2, 1, 0)
            step_sz = 0.002
            num_steps = 20
            traj_opt = lambda x0, x1: bvp_solver.solve(x0, x1, 200, num_steps,
                                                       1, 20, step_sz)

        elif args.env_type == 'cartpole_obs':
            #system = standard_cpp_systems.RectangleObs(obs[i], 4.0, 'cartpole')
            obs_width = 4.0
            psopt_system = _sst_module.PSOPTCartPole()
            propagate_system = standard_cpp_systems.RectangleObs(
                obs, 4., 'cartpole')
            #distance_computer = propagate_system.distance_computer()
            distance_computer = _sst_module.euclidean_distance(
                np.array(propagate_system.is_circular_topology()))
            step_sz = 0.002
            num_steps = 21
            goal_radius = 1.5
            random_seed = 0
            delta_near = 2.0
            delta_drain = 1.2
            #delta_near=.2
            #delta_drain=.1
            cost_threshold = 1.05
            min_time_steps = 10
            max_time_steps = 200
            #min_time_steps = 5
            #max_time_steps = 400
            integration_step = 0.002
        elif args.env_type in [
                'acrobot_obs', 'acrobot_obs_2', 'acrobot_obs_3',
                'acrobot_obs_4', 'acrobot_obs_8'
        ]:
            #system = standard_cpp_systems.RectangleObs(obs[i], 6.0, 'acrobot')
            obs_width = 6.0
            psopt_system = _sst_module.PSOPTAcrobot()
            propagate_system = standard_cpp_systems.RectangleObs(
                obs, 6., 'acrobot')
            distance_computer = propagate_system.distance_computer()
            #distance_computer = _sst_module.euclidean_distance(np.array(propagate_system.is_circular_topology()))
            step_sz = 0.02
            num_steps = 21
            goal_radius = 2.0
            random_seed = 0
            delta_near = 1.0
            delta_drain = 0.5
            cost_threshold = 1.05
            min_time_steps = 5
            max_time_steps = 100
            integration_step = 0.02
        planner = _sst_module.SSTWrapper(
            state_bounds=propagate_system.get_state_bounds(),
            control_bounds=propagate_system.get_control_bounds(),
            distance=distance_computer,
            start_state=start_state,
            goal_state=goal_state,
            goal_radius=goal_radius,
            random_seed=random_seed,
            sst_delta_near=delta_near,
            sst_delta_drain=delta_drain)
        #print('creating planner...')
        # generate a path by using SST to plan for some maximal iterations
        time0 = time.time()

        for i in range(max_iteration):
            planner.step(propagate_system, min_time_steps, max_time_steps,
                         integration_step)

            # early break for initial path
            solution = planner.get_solution()
            if solution is not None:
                #print('solution found already.')
                # based on cost break
                xs, us, ts = solution
                t_sst = np.sum(ts)
                #print(t_sst)
                #print(cost_i)
                if t_sst <= cost_i * cost_threshold:
                    print('solved in %d iterations' % (i))
                    break
        plan_time = time.time() - time0
        solution = planner.get_solution()
        xs, us, ts = solution
        print(np.linalg.norm(np.array(xs[-1]) - goal_state))
        """
        # visualization
        plt.ion()
        fig = plt.figure()
        ax = fig.add_subplot(111)
        #ax.set_autoscale_on(True)
        #ax.set_xlim(-30, 30)
        ax.set_xlim(-np.pi, np.pi)
        ax.set_ylim(-np.pi, np.pi)
        hl, = ax.plot([], [], 'b')
        #hl_real, = ax.plot([], [], 'r')
        hl_for, = ax.plot([], [], 'g')
        hl_back, = ax.plot([], [], 'r')
        hl_for_mpnet, = ax.plot([], [], 'lightgreen')
        hl_back_mpnet, = ax.plot([], [], 'salmon')

        #print(obs)
        def update_line(h, ax, new_data):
            new_data = wrap_angle(new_data, propagate_system)
            h.set_data(np.append(h.get_xdata(), new_data[0]), np.append(h.get_ydata(), new_data[1]))
            #h.set_xdata(np.append(h.get_xdata(), new_data[0]))
            #h.set_ydata(np.append(h.get_ydata(), new_data[1]))

        def remove_last_k(h, ax, k):
            h.set_data(h.get_xdata()[:-k], h.get_ydata()[:-k])

        def draw_update_line(ax):
            #ax.relim()
            #ax.autoscale_view()
            fig.canvas.draw()
            fig.canvas.flush_events()
            #plt.show()
            
        def wrap_angle(x, system):
            circular = system.is_circular_topology()
            res = np.array(x)
            for i in range(len(x)):
                if circular[i]:
                    # use our previously saved version
                    res[i] = x[i] - np.floor(x[i] / (2*np.pi))*(2*np.pi)
                    if res[i] > np.pi:
                        res[i] = res[i] - 2*np.pi
            return res
        dx = 1
        dtheta = 0.1
        feasible_points = []
        infeasible_points = []
        imin = 0
        #imax = int(2*30./dx)
        imax = int(2*np.pi/dtheta)
        jmin = 0
        jmax = int(2*np.pi/dtheta)


        for i in range(imin, imax):
            for j in range(jmin, jmax):
                x = np.array([dtheta*i-np.pi, dtheta*j-np.pi, 0., 0.])
                if IsInCollision(x, obs_i):
                    infeasible_points.append(x)
                else:
                    feasible_points.append(x)
        feasible_points = np.array(feasible_points)
        infeasible_points = np.array(infeasible_points)
        print('feasible points')
        print(feasible_points)
        print('infeasible points')
        print(infeasible_points)
        ax.scatter(feasible_points[:,0], feasible_points[:,1], c='yellow')
        ax.scatter(infeasible_points[:,0], infeasible_points[:,1], c='pink')
        #for i in range(len(data)):
        #    update_line(hl, ax, data[i])
        draw_update_line(ax)
        #state_t = start_state

        xs_to_plot = np.array(detailed_data_path)
        for i in range(len(xs_to_plot)):
            xs_to_plot[i] = wrap_angle(detailed_data_path[i], propagate_system)
        ax.scatter(xs_to_plot[:,0], xs_to_plot[:,1], c='lightgreen', s=0.5)
        # draw start and goal
        #ax.scatter(start_state[0], goal_state[0], marker='X')
        draw_update_line(ax)
        #plt.waitforbuttonpress()

        
        if solution is not None:
            xs, us, ts = solution
            
            # propagate data
            p_start = xs[0]
            detail_paths = [p_start]
            detail_controls = []
            detail_costs = []
            state = [p_start]
            control = []
            cost = []
            for k in range(len(us)):
                #state_i.append(len(detail_paths)-1)
                print(ts[k])
                max_steps = int(np.round(ts[k]/step_sz))
                accum_cost = 0.
                #print('p_start:')
                #print(p_start)
                #print('data:')
                #print(paths[i][j][k])
                # modify it because of small difference between data and actual propagation
                for step in range(1,max_steps+1):
                    p_start = dynamics(p_start, us[k], step_sz)
                    p_start = enforce_bounds(p_start)
                    detail_paths.append(p_start)
                    accum_cost += step_sz
                    if (step % 1 == 0) or (step == max_steps):
                        state.append(p_start)
                        #print('control')
                        #print(controls[i][j])
                        cost.append(accum_cost)
                        accum_cost = 0.
            #print('p_start:')
            #print(p_start)
            #print('data:')
            #print(paths[i][j][-1])
            #state[-1] = xs[-1]
            
            
            
            xs_to_plot = np.array(state)
            for i in range(len(xs_to_plot)):
                xs_to_plot[i] = wrap_angle(xs_to_plot[i], propagate_system)
            ax.scatter(xs_to_plot[:,0], xs_to_plot[:,1], c='green', s=0.5)
            start_state_np = np.array(start_state)
            goal_state_np = np.array(goal_state)
            ax.scatter([start_state_np[0]], [start_state_np[1]], c='blue', marker='*')
            ax.scatter([goal_state_np[0]], [goal_state_np[1]], c='red', marker='*')

            # draw start and goal
            #ax.scatter(start_state[0], goal_state[0], marker='X')
            draw_update_line(ax)
            plt.waitforbuttonpress()
        """

        # validate if the path contains collision
        if solution is not None:
            res_x, res_u, res_t = solution

            print('solution_x:')
            print(res_x)
            print('path_x:')
            print(np.array(data_path))

            # propagate data
            p_start = res_x[0]
            detail_paths = [p_start]
            detail_controls = []
            detail_costs = []
            state = [p_start]
            control = []
            cost = []
            for k in range(len(res_u)):
                #state_i.append(len(detail_paths)-1)
                max_steps = int(np.round(res_t[k] / step_sz))
                accum_cost = 0.
                #print('p_start:')
                #print(p_start)
                #print('data:')
                #print(paths[i][j][k])
                # modify it because of small difference between data and actual propagation
                #p_start = res_x[k]
                #state[-1] = res_x[k]
                for step in range(1, max_steps + 1):
                    p_start = dynamics(p_start, res_u[k], step_sz)
                    p_start = enforce_bounds(p_start)
                    detail_paths.append(p_start)
                    accum_cost += step_sz
                    if (step % 1 == 0) or (step == max_steps):
                        state.append(p_start)
                        #print('control')
                        #print(controls[i][j])
                        cost.append(accum_cost)
                        accum_cost = 0.
                        # check collision for the new state
                        if IsInCollision(p_start, obs_i):
                            print(
                                'collision happens at u_index: %d, step: %d' %
                                (k, step))
                        assert not IsInCollision(p_start, obs_i)

            #print('p_start:')
            #print(p_start)
            #print('data:')
            #print(paths[i][j][-1])
            #state[-1] = res_x[-1]
        # validation end

        print('plan time: %fs' % (plan_time))
        if solution is None:
            print('failed.')
            out_queue_t.put(-1)
            out_queue_cost.put(-1)
        else:
            print('path succeeded.')
            out_queue_t.put(plan_time)
            out_queue_cost.put(t_sst)

    ####################################################################################

    # load data
    print('loading...')
    if args.seen_N > 0:
        seen_test_data = data_loader.load_test_dataset(args.seen_N,
                                                       args.seen_NP,
                                                       args.data_folder, obs_f,
                                                       args.seen_s,
                                                       args.seen_sp)
    if args.unseen_N > 0:
        unseen_test_data = data_loader.load_test_dataset(
            args.unseen_N, args.unseen_NP, args.data_folder, obs_f,
            args.unseen_s, args.unseen_sp)
    # test
    # testing

    queue_t = Queue(1)
    queue_cost = Queue(1)
    print('testing...')
    seen_test_suc_rate = 0.
    unseen_test_suc_rate = 0.

    obc, obs, paths, sgs, path_lengths, controls, costs = seen_test_data

    obc = obc.astype(np.float32)
    # for all planning, use a flattened vector to store
    plan_times = []
    plan_res_all = []
    plan_costs = []
    data_costs = []

    # store in a 2d vector, for env and path
    plan_res_env = []
    plan_time_env = []
    plan_cost_env = []
    data_cost_env = []

    # directory to save the results
    res_path = args.res_path
    res_path = res_path + args.env_type + "_sst_compare_with_mpnet/"

    if args.env_type == 'acrobot_obs':
        res_path = '/media/arclabdl1/HD1/YLmiao/mpc-mpnet-cuda-yinglong/results/cpp_full/acrobot_obs/default_small_model_batch/'
    elif args.env_type == 'cartpole_obs':
        res_path = '/media/arclabdl1/HD1/YLmiao/mpc-mpnet-cuda-yinglong/results/cpp_full/cartpole_obs/default_small_model_batch/'

    mpnet_tree_time = np.load(res_path + 'time_10_100.npy', allow_pickle=True)
    mpnet_tree_sr = np.load(res_path + 'sr_10_100.npy', allow_pickle=True)
    mpnet_tree_cost = np.load(res_path + 'costs_10_100.npy', allow_pickle=True)
    if not os.path.exists(res_path):
        os.makedirs(res_path)

    for i in range(len(paths)):
        new_obs_i = []
        obs_i = obs[i]
        plan_res_path = []
        plan_time_path = []
        plan_cost_path = []
        data_cost_path = []
        for k in range(len(obs_i)):
            obs_pt = []
            obs_pt.append(obs_i[k][0] - obs_width / 2)
            obs_pt.append(obs_i[k][1] - obs_width / 2)
            obs_pt.append(obs_i[k][0] - obs_width / 2)
            obs_pt.append(obs_i[k][1] + obs_width / 2)
            obs_pt.append(obs_i[k][0] + obs_width / 2)
            obs_pt.append(obs_i[k][1] + obs_width / 2)
            obs_pt.append(obs_i[k][0] + obs_width / 2)
            obs_pt.append(obs_i[k][1] - obs_width / 2)
            new_obs_i.append(obs_pt)
        obs_i = new_obs_i
        #print(obs_i)
        for j in range(len(paths[i])):
            start_state = sgs[i][j][0]
            #goal_inform_state = paths[i][j][-1]
            goal_inform_state = sgs[i][j][1]
            goal_state = sgs[i][j][1]
            # propagate data
            p_start = paths[i][j][0]
            detail_paths = [p_start]
            detail_controls = []
            detail_costs = []
            state = [p_start]
            control = []
            cost = []
            for k in range(len(controls[i][j])):
                #state_i.append(len(detail_paths)-1)
                #max_steps = int(costs[i][j][k]/step_sz)
                max_steps = 1000000
                accum_cost = 0.
                for step in range(1, max_steps + 1):
                    p_start = dynamics(p_start, controls[i][j][k], step_sz)
                    p_start = enforce_bounds(p_start)
                    detail_paths.append(p_start)
                    accum_cost += step_sz
                    if (step % 1 == 0) or (step == max_steps):
                        state.append(p_start)
                        #print('control')
                        #print(controls[i][j])
                        cost.append(accum_cost)
                        accum_cost = 0.
                        # check collision for the new state
                        if IsInCollision(p_start, obs_i):
                            print(
                                'collision happens at u_index: %d, step: %d' %
                                (k, step))
                        assert not IsInCollision(p_start, obs_i)
                    if np.linalg.norm(p_start - paths[i][j][k + 1]) <= 1e-3:
                        break
        # validation end

            cost_i = np.sum(cost)
            print('data cost:', cost_i)
            if mpnet_tree_sr[i][j] != 0:
                # use MPNet tree cost
                cost_i = mpnet_tree_cost[i][j]
                print('using mpnet cost: ', cost_i)

            #cost_i = 100000000.
            # acrobot: 300000
            # cartpole: 500000
            print('environment: %d/%d, path: %d/%d' %
                  (i + 1, len(paths), j + 1, len(paths[i])))
            plan_t_trials = []
            plan_cost_trials = []
            for trial in range(1):
                random_seed = random.randint(0, 100)
                #random_seed = 0
                p = Process(target=plan_one_path,
                            args=(obs_i, obs[i], obc[i], state, paths[i][j],
                                  start_state, goal_state, goal_inform_state,
                                  cost_i, args.num_iter, queue_t, queue_cost,
                                  random_seed))
                #plan_one_path(obs_i, obs[i], obc[i], state, paths[i][j], start_state, goal_state, goal_inform_state, cost_i, args.num_iter, queue_t, queue_cost, random_seed)
                p.start()
                p.join()
                plan_t = queue_t.get()
                plan_cost = queue_cost.get()
                if plan_t != -1:
                    plan_t_trials.append(plan_t)
                    plan_cost_trials.append(plan_cost)
            #assert len(plan_ts) == 10

            plan_t = np.mean(plan_t_trials)
            plan_cost = np.mean(plan_cost_trials)

            if plan_t == -1:
                # failed, do not record in the flattened list
                plan_res_all.append(0)
                # record in the 2d list
                plan_res_path.append(0)
                plan_time_path.append(plan_t)
                plan_cost_path.append(plan_cost)
                data_cost_path.append(-1.0)
            else:
                # record in the flattened list
                plan_res_all.append(1)
                plan_times.append(plan_t)
                plan_costs.append(plan_cost)
                data_costs.append(cost_i)
                # record in the 2d list
                plan_res_path.append(1)
                plan_time_path.append(plan_t)
                plan_cost_path.append(plan_cost)
                data_cost_path.append(cost_i)
            print('plan costs:')
            print(plan_costs)
            print('average accuracy up to now: %f' %
                  (np.array(plan_res_all).flatten().mean()))
            print('plan average time: %f' % (np.array(plan_times).mean()))
            print('plan time std: %f' % (np.array(plan_times).std()))
            print('plan average cost: %f' % (np.array(plan_costs).mean()))
            print('plan cost std: %f' % (np.array(plan_costs).std()))
            print('data average cost: %f' % (np.array(data_costs).mean()))
            print('data cost std: %f' % (np.array(data_costs).std()))

        # store in the 2d list
        plan_res_env.append(plan_res_path)
        plan_time_env.append(plan_time_path)
        plan_cost_env.append(plan_cost_path)
        data_cost_env.append(data_cost_path)

        # for every environment planned, save
        # save the 2d list
        # save as numpy array
        #np.save(res_path+"plan_res.npy", np.array(plan_res_env))
        #np.save(res_path+"plan_time.npy", np.array(plan_time_env))
        #np.save(res_path+"plan_cost.npy", np.array(plan_cost_env))
        #np.save(res_path+"data_cost.npy", np.array(data_cost_env))

    print('plan accuracy: %f' % (np.array(plan_res_all).flatten().mean()))
    print('plan average time: %f' % (np.array(plan_times).mean()))
    print('plan time std: %f' % (np.array(plan_times).std()))
    print('plan average cost: %f' % (np.array(plan_costs).mean()))
    print('plan cost std: %f' % (np.array(plan_costs).std()))
    print('data average cost: %f' % (np.array(data_costs).mean()))
    print('data cost std: %f' % (np.array(data_costs).std()))

    # save the 2d list
    # save as numpy array
    plan_res_env = np.array(plan_res_env)
    plan_time_env = np.array(plan_time_env)
    plan_cost_env = np.array(plan_cost_env)
    data_cost_env = np.array(data_cost_env)

    np.save(res_path + "plan_res.npy", plan_res_env)
    np.save(res_path + "plan_time.npy", plan_time_env)
    np.save(res_path + "plan_cost.npy", plan_cost_env)
    np.save(res_path + "data_cost.npy", data_cost_env)
Example #4
0
def main(args):
    # set seed
    print(args.model_path)
    torch_seed = np.random.randint(low=0, high=1000)
    np_seed = np.random.randint(low=0, high=1000)
    py_seed = np.random.randint(low=0, high=1000)
    torch.manual_seed(torch_seed)
    np.random.seed(np_seed)
    random.seed(py_seed)
    # Build the models
    if torch.cuda.is_available():
        torch.cuda.set_device(args.device)

    # setup evaluation function and load function
    if args.env_type == 'pendulum':
        IsInCollision = pendulum.IsInCollision
        normalize = pendulum.normalize
        unnormalize = pendulum.unnormalize
        obs_file = None
        obc_file = None
        dynamics = pendulum.dynamics
        jax_dynamics = pendulum.jax_dynamics
        enforce_bounds = pendulum.enforce_bounds
        cae = cae_identity
        mlp = MLP
        obs_f = False
        #system = standard_cpp_systems.PSOPTPendulum()
        #bvp_solver = _sst_module.PSOPTBVPWrapper(system, 2, 1, 0)
    elif args.env_type == 'cartpole_obs':
        IsInCollision = cartpole.IsInCollision
        normalize = cartpole.normalize
        unnormalize = cartpole.unnormalize
        obs_file = None
        obc_file = None
        dynamics = cartpole.dynamics
        jax_dynamics = cartpole.jax_dynamics
        enforce_bounds = cartpole.enforce_bounds
        cae = CAE_acrobot_voxel_2d
        mlp = mlp_acrobot.MLP
        obs_f = True
        #system = standard_cpp_systems.RectangleObs(obs_list, args.obs_width, 'cartpole')
        #bvp_solver = _sst_module.PSOPTBVPWrapper(system, 4, 1, 0)
    elif args.env_type == 'acrobot_obs':
        IsInCollision = acrobot_obs.IsInCollision
        #IsInCollision = lambda x, obs: False
        normalize = acrobot_obs.normalize
        unnormalize = acrobot_obs.unnormalize
        obs_file = None
        obc_file = None
        system = _sst_module.PSOPTAcrobot()
        cpp_propagator = _sst_module.SystemPropagator()
        dynamics = lambda x, u, t: cpp_propagator.propagate(system, x, u, t)
        xdot = acrobot_obs.dynamics
        jax_dynamics = acrobot_obs.jax_dynamics
        enforce_bounds = acrobot_obs.enforce_bounds
        cae = CAE_acrobot_voxel_2d
        mlp = mlp_acrobot.MLP
        obs_f = True
        bvp_solver = _sst_module.PSOPTBVPWrapper(system, 4, 1, 0)
        step_sz = 0.02
        num_steps = 21
        traj_opt = lambda x0, x1, step_sz, num_steps, x_init, u_init, t_init: bvp_solver.solve(
            x0, x1, 50, num_steps, step_sz * 1, step_sz *
            (num_steps - 1), x_init, u_init, t_init)
        goal_S0 = np.diag([1., 1., 0, 0])
        #goal_S0 = np.identity(4)
        goal_rho0 = 1.0

    elif args.env_type == 'acrobot_obs_2':
        IsInCollision = acrobot_obs.IsInCollision
        normalize = acrobot_obs.normalize
        unnormalize = acrobot_obs.unnormalize
        obs_file = None
        obc_file = None
        system = _sst_module.PSOPTAcrobot()
        cpp_propagator = _sst_module.SystemPropagator()
        dynamics = lambda x, u, t: cpp_propagator.propagate(system, x, u, t)
        xdot = acrobot_obs.dynamics
        jax_dynamics = acrobot_obs.jax_dynamics
        enforce_bounds = acrobot_obs.enforce_bounds
        cae = CAE_acrobot_voxel_2d_2
        mlp = mlp_acrobot.MLP2
        obs_f = True
        bvp_solver = _sst_module.PSOPTBVPWrapper(system, 4, 1, 0)
        step_sz = 0.02
        num_steps = 21
        traj_opt = lambda x0, x1, step_sz, num_steps, x_init, u_init, t_init: bvp_solver.solve(
            x0, x1, 400, num_steps, step_sz * 1, step_sz *
            (num_steps - 1), x_init, u_init, t_init)
        goal_S0 = np.diag([1., 1., 0, 0])
        #goal_S0 = np.identity(4)
        goal_rho0 = 1.0

    elif args.env_type == 'acrobot_obs_3':
        IsInCollision = acrobot_obs.IsInCollision
        normalize = acrobot_obs.normalize
        unnormalize = acrobot_obs.unnormalize
        obs_file = None
        obc_file = None
        system = _sst_module.PSOPTAcrobot()
        cpp_propagator = _sst_module.SystemPropagator()
        dynamics = lambda x, u, t: cpp_propagator.propagate(system, x, u, t)
        xdot = acrobot_obs.dynamics
        jax_dynamics = acrobot_obs.jax_dynamics
        enforce_bounds = acrobot_obs.enforce_bounds
        mlp = mlp_acrobot.MLP3
        cae = CAE_acrobot_voxel_2d_2
        obs_f = True
        bvp_solver = _sst_module.PSOPTBVPWrapper(system, 4, 1, 0)
        step_sz = 0.02
        num_steps = 21
        traj_opt = lambda x0, x1, step_sz, num_steps, x_init, u_init, t_init: bvp_solver.solve(
            x0, x1, 400, num_steps, step_sz * 1, step_sz *
            (num_steps - 1), x_init, u_init, t_init)
        goal_S0 = np.diag([1., 1., 0, 0])
        #goal_S0 = np.identity(4)
        goal_rho0 = 1.0

    elif args.env_type == 'acrobot_obs_5':
        IsInCollision = acrobot_obs.IsInCollision
        normalize = acrobot_obs.normalize
        unnormalize = acrobot_obs.unnormalize
        obs_file = None
        obc_file = None
        system = _sst_module.PSOPTAcrobot()
        cpp_propagator = _sst_module.SystemPropagator()
        dynamics = lambda x, u, t: cpp_propagator.propagate(system, x, u, t)
        xdot = acrobot_obs.dynamics
        jax_dynamics = acrobot_obs.jax_dynamics
        enforce_bounds = acrobot_obs.enforce_bounds
        cae = CAE_acrobot_voxel_2d_3
        mlp = mlp_acrobot.MLP
        obs_f = True
        bvp_solver = _sst_module.PSOPTBVPWrapper(system, 4, 1, 0)
        step_sz = 0.02
        num_steps = 21
        traj_opt = lambda x0, x1, step_sz, num_steps, x_init, u_init, t_init: bvp_solver.solve(
            x0, x1, 400, num_steps, step_sz * 1, step_sz *
            (num_steps - 1), x_init, u_init, t_init)
        goal_S0 = np.diag([1., 1., 0, 0])
        #goal_S0 = np.identity(4)
        goal_rho0 = 1.0
    elif args.env_type == 'acrobot_obs_6':
        IsInCollision = acrobot_obs.IsInCollision
        normalize = acrobot_obs.normalize
        unnormalize = acrobot_obs.unnormalize
        obs_file = None
        obc_file = None
        xdot = acrobot_obs.dynamics
        system = _sst_module.PSOPTAcrobot()
        cpp_propagator = _sst_module.SystemPropagator()
        dynamics = lambda x, u, t: cpp_propagator.propagate(system, x, u, t)
        jax_dynamics = acrobot_obs.jax_dynamics
        enforce_bounds = acrobot_obs.enforce_bounds
        cae = CAE_acrobot_voxel_2d_3
        mlp = mlp_acrobot.MLP4
        obs_f = True
        bvp_solver = _sst_module.PSOPTBVPWrapper(system, 4, 1, 0)
        step_sz = 0.02
        num_steps = 21
        traj_opt = lambda x0, x1, step_sz, num_steps, x_init, u_init, t_init: bvp_solver.solve(
            x0, x1, 400, num_steps, step_sz * 1, step_sz *
            (num_steps - 1), x_init, u_init, t_init)
        goal_S0 = np.diag([1., 1., 0, 0])
        #goal_S0 = np.identity(4)
        goal_rho0 = 1.0

    elif args.env_type == 'acrobot_obs_6':
        IsInCollision = acrobot_obs.IsInCollision
        normalize = acrobot_obs.normalize
        unnormalize = acrobot_obs.unnormalize
        obs_file = None
        obc_file = None
        xdot = acrobot_obs.dynamics
        system = _sst_module.PSOPTAcrobot()
        cpp_propagator = _sst_module.SystemPropagator()
        dynamics = lambda x, u, t: cpp_propagator.propagate(system, x, u, t)
        jax_dynamics = acrobot_obs.jax_dynamics
        enforce_bounds = acrobot_obs.enforce_bounds
        mlp = mlp_acrobot.MLP5
        cae = CAE_acrobot_voxel_2d_3
        obs_f = True
        bvp_solver = _sst_module.PSOPTBVPWrapper(system, 4, 1, 0)
        step_sz = 0.02
        num_steps = 21
        traj_opt = lambda x0, x1, step_sz, num_steps, x_init, u_init, t_init: bvp_solver.solve(
            x0, x1, 400, num_steps, step_sz * 1, step_sz *
            (num_steps - 1), x_init, u_init, t_init)
        goal_S0 = np.diag([1., 1., 0, 0])
        #goal_S0 = np.identity(4)
        goal_rho0 = 1.0

    elif args.env_type == 'acrobot_obs_8':
        IsInCollision = acrobot_obs.IsInCollision
        #IsInCollision = lambda x, obs: False
        normalize = acrobot_obs.normalize
        unnormalize = acrobot_obs.unnormalize
        obs_file = None
        obc_file = None
        system = _sst_module.PSOPTAcrobot()
        cpp_propagator = _sst_module.SystemPropagator()
        dynamics = lambda x, u, t: cpp_propagator.propagate(system, x, u, t)
        xdot = acrobot_obs.dynamics
        jax_dynamics = acrobot_obs.jax_dynamics
        enforce_bounds = acrobot_obs.enforce_bounds
        cae = CAE_acrobot_voxel_2d_3
        mlp = mlp_acrobot.MLP6
        obs_f = True
        bvp_solver = _sst_module.PSOPTBVPWrapper(system, 4, 1, 0)
        step_sz = 0.02
        #num_steps = 21
        num_steps = 21  #args.num_steps*2
        traj_opt = lambda x0, x1, step_sz, num_steps, x_init, u_init, t_init: bvp_solver.solve(
            x0, x1, 400, num_steps, step_sz * 1, step_sz *
            (num_steps - 1), x_init, u_init, t_init)
        #traj_opt = lambda x0, x1, step_sz, num_steps, x_init, u_init, t_init:
        #def cem_trajopt(x0, x1, step_sz, num_steps, x_init, u_init, t_init):
        #    u, t = acrobot_obs.trajopt(x0, x1, 500, num_steps, step_sz*1, step_sz*(num_steps-1), x_init, u_init, t_init)
        #    xs, us, dts, valid = propagate(x0, u, t, dynamics=dynamics, enforce_bounds=enforce_bounds, IsInCollision=lambda x: False, system=system, step_sz=step_sz)
        #    return xs, us, dts
        #traj_opt = cem_trajopt
        goal_S0 = np.diag([1., 1., 0, 0])
        goal_rho0 = 1.0

    mpNet0 = KMPNet(args.total_input_size, args.AE_input_size,
                    args.mlp_input_size, args.output_size, cae, mlp)
    mpNet1 = KMPNet(args.total_input_size, args.AE_input_size,
                    args.mlp_input_size, args.output_size, cae, mlp)

    # load previously trained model if start epoch > 0
    #model_path='kmpnet_epoch_%d_direction_0_step_%d.pkl' %(args.start_epoch, args.num_steps)
    model_path = 'kmpnet_epoch_%d_direction_0.pkl' % (args.start_epoch)
    if args.start_epoch > 0:
        load_net_state(mpNet0, os.path.join(args.model_path, model_path))
        torch_seed, np_seed, py_seed = load_seed(
            os.path.join(args.model_path, model_path))
        # set seed after loading
        torch.manual_seed(torch_seed)
        np.random.seed(np_seed)
        random.seed(py_seed)
    if torch.cuda.is_available():
        mpNet0.cuda()
        mpNet0.mlp.cuda()
        mpNet0.encoder.cuda()
        if args.opt == 'Adagrad':
            mpNet0.set_opt(torch.optim.Adagrad, lr=args.learning_rate)
        elif args.opt == 'Adam':
            mpNet0.set_opt(torch.optim.Adam, lr=args.learning_rate)
        elif args.opt == 'SGD':
            mpNet0.set_opt(torch.optim.SGD,
                           lr=args.learning_rate,
                           momentum=0.9)
    if args.start_epoch > 0:
        load_opt_state(mpNet0, os.path.join(args.model_path, model_path))

    # load previously trained model if start epoch > 0
    #model_path='kmpnet_epoch_%d_direction_1_step_%d.pkl' %(args.start_epoch, args.num_steps)
    model_path = 'kmpnet_epoch_%d_direction_1.pkl' % (args.start_epoch)
    if args.start_epoch > 0:
        load_net_state(mpNet1, os.path.join(args.model_path, model_path))
        torch_seed, np_seed, py_seed = load_seed(
            os.path.join(args.model_path, model_path))
        # set seed after loading
        torch.manual_seed(torch_seed)
        np.random.seed(np_seed)
        random.seed(py_seed)
    if torch.cuda.is_available():
        mpNet1.cuda()
        mpNet1.mlp.cuda()
        mpNet1.encoder.cuda()
        if args.opt == 'Adagrad':
            mpNet1.set_opt(torch.optim.Adagrad, lr=args.learning_rate)
        elif args.opt == 'Adam':
            mpNet1.set_opt(torch.optim.Adam, lr=args.learning_rate)
        elif args.opt == 'SGD':
            mpNet1.set_opt(torch.optim.SGD,
                           lr=args.learning_rate,
                           momentum=0.9)
    if args.start_epoch > 0:
        load_opt_state(mpNet1, os.path.join(args.model_path, model_path))

    # define informer
    circular = system.is_circular_topology()

    def informer(env, x0, xG, direction):
        """Propose the next waypoint with MPNet and build a warm start for traj-opt.

        The forward network (``mpNet0``) is queried when ``direction == 0`` and
        the backward network (``mpNet1``) otherwise.  The predicted state is
        turned into a displacement from ``x0``; on circular dimensions the
        displacement is wrapped into (-pi, pi] and, with probability 1/2 when
        its magnitude is at least pi/2, replaced by the rotation going the
        other way around the circle.

        Relies on names from the enclosing scope / module: ``mpNet0``,
        ``mpNet1``, ``circular``, ``num_steps``, ``step_sz``, ``Node``,
        ``normalize_func`` and ``unnormalize_func``.
        NOTE(review): ``normalize_func`` / ``unnormalize_func`` are not defined
        in the visible part of this function's enclosing scope -- confirm they
        exist at module level.

        Args:
            env: environment encoding fed to the network (a tensor; it is
                ``unsqueeze``d to add a batch dimension).
            x0: node to extend from; its state is read from ``x0.x``.
            xG: node to extend towards; its state is read from ``xG.x``.
            direction: 0 for the forward tree, anything else for the backward
                tree.

        Returns:
            Tuple ``(res, x_init, u_init, t_init)``: ``res`` is a ``Node`` at
            the proposed waypoint; ``x_init`` has ``num_steps`` rows of states
            (from ``x0`` to the waypoint for the forward direction, reversed for
            the backward direction) with Gaussian noise on the interior rows;
            ``u_init`` is a ``(num_steps, 1)`` array of random controls;
            ``t_init`` holds ``num_steps`` evenly spaced time stamps.
        """
        x0_x = normalize_func(torch.from_numpy(x0.x).type(torch.FloatTensor))
        xG_x = normalize_func(torch.from_numpy(xG.x).type(torch.FloatTensor))
        if torch.cuda.is_available():
            x0_x = x0_x.cuda()
            xG_x = xG_x.cuda()

        forward = (direction == 0)
        mpNet = mpNet0 if forward else mpNet1
        # concatenating two tensors on the same device keeps that device, so no
        # extra .cuda() call is needed on the result
        x = torch.cat([x0_x, xG_x], dim=0)
        next_state = mpNet(x.unsqueeze(0), env.unsqueeze(0)).cpu().data
        next_state = unnormalize_func(next_state).numpy()[0]

        delta_x = next_state - x0.x
        two_pi = 2 * np.pi
        for i in range(len(delta_x)):
            if not circular[i]:
                continue
            # wrap into (-pi, pi]: the shorter way around the circle
            delta_x[i] = delta_x[i] - np.floor(delta_x[i] / two_pi) * two_pi
            if delta_x[i] > np.pi:
                delta_x[i] = delta_x[i] - two_pi
            # randomly pick either direction
            rand_d = np.random.randint(2)
            if rand_d < 1 and np.abs(delta_x[i]) >= np.pi * 0.5:
                # The two cases must be mutually exclusive: with two separate
                # `if`s a positive delta was decreased by 2*pi and then
                # immediately increased again, so the flip never happened.
                if delta_x[i] > 0.:
                    delta_x[i] = delta_x[i] - two_pi
                else:
                    delta_x[i] = delta_x[i] + two_pi

        next_state = x0.x + delta_x
        res = Node(next_state)

        # Zero-mean perturbation of the interior knots of the state guess; the
        # two end points are kept exact.  This is sampled for BOTH directions:
        # previously the backward branch used `rand_x_init` without ever
        # assigning it and raised UnboundLocalError.
        cov = np.diag([0.02, 0.02, 0.02, 0.02])
        mean = np.zeros(next_state.shape)
        rand_x_init = np.random.multivariate_normal(mean=mean,
                                                    cov=cov,
                                                    size=num_steps)
        rand_x_init[0] = rand_x_init[0] * 0.
        rand_x_init[-1] = rand_x_init[-1] * 0.

        if forward:
            x_init = np.linspace(x0.x, next_state, num_steps) + rand_x_init
        else:
            # backward tree: the guess runs from the waypoint towards x0
            x_init = np.linspace(next_state, x0.x, num_steps) + rand_x_init

        ## TODO: : change this to general case
        u_init = np.random.uniform(low=[-4.],
                                   high=[4],
                                   size=(num_steps, 1))
        cost_i = (num_steps - 1) * step_sz  #TOEDIT
        t_init = np.linspace(0, cost_i, num_steps)
        return res, x_init, u_init, t_init

    def init_informer(env, x0, xG, direction):
        """Build the initial guess handed to the trajectory optimizer.

        The guess consists of a noisy straight line in state space between
        ``x0.x`` and ``xG.x``, uniformly random controls, and an evenly
        spaced time grid.  ``circular``, ``num_steps`` and ``step_sz`` are
        read from the enclosing scope.

        Args:
            env: Not used by this function; kept so the signature is
                unchanged for callers.
            x0: Object whose ``.x`` attribute is the start state array.
            xG: Object whose ``.x`` attribute is the target state array.
            direction: ``0`` interpolates from ``x0.x`` towards ``xG.x``;
                any other value interpolates from ``xG.x`` towards ``x0.x``.

        Returns:
            Tuple ``(x_init, u_init, t_init)``.  ``x_init`` has
            ``num_steps`` rows (one state per step), ``u_init`` has shape
            ``(num_steps, 1)`` and ``t_init`` has ``num_steps`` entries
            spanning ``[0, (num_steps - 1) * step_sz]``.
        """
        two_pi = 2 * np.pi

        def wrap_delta(delta, flip_threshold):
            # Map every circular coordinate of ``delta`` to (-pi, pi], i.e.
            # the shorter way round.  Then, with probability 1/2, take the
            # longer way instead when the short one is already at least
            # ``flip_threshold`` long.
            for i in range(len(delta)):
                if not circular[i]:
                    continue
                delta[i] = delta[i] - np.floor(delta[i] / two_pi) * two_pi
                if delta[i] > np.pi:
                    delta[i] = delta[i] - two_pi
                # randomly pick either direction
                rand_d = np.random.randint(2)
                if rand_d < 1 and np.abs(delta[i]) >= flip_threshold:
                    # Exactly one of the two shifts must apply; testing the
                    # sign again after shifting would undo the flip.
                    if delta[i] > 0.:
                        delta[i] = delta[i] - two_pi
                    else:
                        delta[i] = delta[i] + two_pi
            return delta

        if direction == 0:
            origin = x0.x
            delta_x = wrap_delta(xG.x - x0.x, np.pi * 0.9)
        else:
            origin = xG.x
            delta_x = wrap_delta(x0.x - xG.x, np.pi * 0.5)

        # Zero-mean Gaussian perturbation of the interpolated states
        # (variance 0.02 per dimension).  The first and last rows are left
        # unperturbed so the guess still starts and ends on the endpoints.
        rand_x_init = np.random.multivariate_normal(
            mean=np.zeros(origin.shape),
            cov=0.02 * np.eye(origin.shape[0]),
            size=num_steps)
        rand_x_init[0] = 0.
        rand_x_init[-1] = 0.
        x_init = np.linspace(origin, origin + delta_x, num_steps) + rand_x_init

        # TODO: change this to the general case (control bounds and control
        # dimension are hard-coded to a single input in [-4, 4]).
        u_init = np.random.uniform(low=[-4.], high=[4], size=(num_steps, 1))

        cost_i = (num_steps - 1) * step_sz
        t_init = np.linspace(0, cost_i, num_steps)
        return x_init, u_init, t_init

    # load data
    print('loading...')
    if args.seen_N > 0:
        seen_test_data = data_loader.load_test_dataset(args.seen_N,
                                                       args.seen_NP,
                                                       args.data_folder, obs_f,
                                                       args.seen_s,
                                                       args.seen_sp)
    if args.unseen_N > 0:
        unseen_test_data = data_loader.load_test_dataset(
            args.unseen_N, args.unseen_NP, args.data_folder, obs_f,
            args.unseen_s, args.unseen_sp)
    # test
    # testing

    print('testing...')
    seen_test_suc_rate = 0.
    unseen_test_suc_rate = 0.
    T = 1
    for _ in range(T):
        # unnormalize function
        normalize_func = lambda x: normalize(x, args.world_size)
        unnormalize_func = lambda x: unnormalize(x, args.world_size)
        # seen
        if args.seen_N > 0:
            time_file = os.path.join(
                args.model_path,
                'time_seen_epoch_%d_mlp.p' % (args.start_epoch))
            fes_path_, valid_path_ = eval_tasks(
                mpNet0, mpNet1, seen_test_data, args.model_path, time_file,
                IsInCollision, normalize_func, unnormalize_func, informer,
                init_informer, system, dynamics, xdot, jax_dynamics,
                enforce_bounds, traj_opt, step_sz, num_steps)
            valid_path = valid_path_.flatten()
            fes_path = fes_path_.flatten(
            )  # notice different environments are involved
            seen_test_suc_rate += fes_path.sum() / valid_path.sum()
        # unseen
        if args.unseen_N > 0:
            time_file = os.path.join(
                args.model_path,
                'time_unseen_epoch_%d_mlp.p' % (args.start_epoch))
            fes_path_, valid_path_ = eval_tasks(
                mpNet0, mpNet1, unseen_test_data, args.model_path, time_file,
                IsInCollision, normalize_func, unnormalize_func, informer,
                init_informer, system, dynamics, xdot, jax_dynamics,
                enforce_bounds, traj_opt, step_sz, num_steps)
            valid_path = valid_path_.flatten()
            fes_path = fes_path_.flatten(
            )  # notice different environments are involved
            unseen_test_suc_rate += fes_path.sum() / valid_path.sum()
    if args.seen_N > 0:
        seen_test_suc_rate = seen_test_suc_rate / T
        f = open(
            os.path.join(args.model_path,
                         'seen_accuracy_epoch_%d.txt' % (args.start_epoch)),
            'w')
        f.write(str(seen_test_suc_rate))
        f.close()
    if args.unseen_N > 0:
        unseen_test_suc_rate = unseen_test_suc_rate / T  # Save the models
        f = open(
            os.path.join(args.model_path,
                         'unseen_accuracy_epoch_%d.txt' % (args.start_epoch)),
            'w')
        f.write(str(unseen_test_suc_rate))
        f.close()
def main(args):
    """Build two KMPNet models and export TorchScript traces of encoder/MLP.

    Two ``KMPNet`` instances (``mpnet_pnet`` and ``mpnet_vnet``) are built,
    each with output size ``args.output_size // 2``, and wrapped in a
    ``PosVelKMPNet``.  When ``args.start_epoch > 0`` their weights and the
    stored RNG seeds are restored from ``model_dir``.  A sample batch (real
    data when ``args.debug`` is set, random tensors otherwise) is then used
    to trace the encoder and the MLP, and the traces are saved to ``.pt``
    files in the current working directory.

    Args:
        args: Parsed command-line namespace.  Fields read here include
            ``device``, ``debug``, ``env_type``, ``loss``, the network size
            fields, ``model_dir``, ``multigoal``, ``learning_rate``, ``opt``,
            ``num_steps``, ``start_epoch``, ``direction`` and, in debug
            mode, the dataset fields.

    Raises:
        ValueError: If ``args.loss`` is not one of ``'mse'``,
            ``'l1_smooth'`` or ``'mse_decoupled'``.
    """
    # load MPNet
    if torch.cuda.is_available():
        torch.cuda.set_device(args.device)

    if args.debug:
        # Project modules that are only needed when real data is loaded.
        from sparse_rrt import _sst_module
        from plan_utility import cart_pole, cart_pole_obs, pendulum, acrobot_obs
        from tools import data_loader

        cpp_propagator = _sst_module.SystemPropagator()
    # NOTE(review): only the 'cartpole_obs' and 'acrobot_obs' branches assign
    # `mlp` and `cae`; the other env types will fail when the nets are built.
    if args.env_type == 'pendulum':
        if args.debug:
            normalize = pendulum.normalize
            unnormalize = pendulum.unnormalize
            # NOTE(review): `standard_cpp_systems` is not imported in this
            # function -- presumably a module-level import; confirm.
            system = standard_cpp_systems.PSOPTPendulum()
            dynamics = None
            enforce_bounds = None
            step_sz = 0.002
            num_steps = 20

    elif args.env_type == 'cartpole':
        if args.debug:
            normalize = cart_pole.normalize
            unnormalize = cart_pole.unnormalize
            # NOTE(review): `cartpole` (no underscore) is not imported above,
            # only `cart_pole` is -- verify which module is intended.
            dynamics = cartpole.dynamics
            system = _sst_module.CartPole()
            enforce_bounds = cartpole.enforce_bounds
            step_sz = 0.002
            num_steps = 20
    elif args.env_type == 'cartpole_obs':
        if args.debug:
            normalize = cart_pole_obs.normalize
            unnormalize = cart_pole_obs.unnormalize
            system = _sst_module.PSOPTCartPole()

            dynamics = lambda x, u, t: cpp_propagator.propagate(system, x, u, t)
            enforce_bounds = cart_pole_obs.enforce_bounds
            step_sz = 0.002
            num_steps = 20
        mlp = mlp_cartpole.MLP
        cae = CAE_cartpole_voxel_2d
    elif args.env_type == 'acrobot_obs':
        if args.debug:
            normalize = acrobot_obs.normalize
            unnormalize = acrobot_obs.unnormalize
            system = _sst_module.PSOPTAcrobot()
            dynamics = lambda x, u, t: cpp_propagator.propagate(system, x, u, t)
            enforce_bounds = acrobot_obs.enforce_bounds
            step_sz = 0.02
            num_steps = 20
        mlp = mlp_acrobot.MLP
        cae = CAE_acrobot_voxel_2d

    if args.loss == 'mse':
        loss_f = nn.MSELoss()

    elif args.loss == 'l1_smooth':
        loss_f = nn.SmoothL1Loss()

    elif args.loss == 'mse_decoupled':
        def mse_decoupled(y1, y2):
            # Per-dimension mean absolute error over the first four columns.
            # Column 2 is treated as an angle: differences above pi are
            # replaced by 2*pi minus the difference.
            l_0 = torch.abs(y1[:,0] - y2[:,0])
            l_1 = torch.abs(y1[:,1] - y2[:,1])
            l_2 = torch.abs(y1[:,2] - y2[:,2]) # angular dimension
            l_3 = torch.abs(y1[:,3] - y2[:,3])
            cond = l_2 > np.pi
            l_2 = torch.where(cond, 2*np.pi-l_2, l_2)
            l_0 = torch.mean(l_0)
            l_1 = torch.mean(l_1)
            l_2 = torch.mean(l_2)
            l_3 = torch.mean(l_3)
            return torch.stack([l_0, l_1, l_2, l_3])
        loss_f = mse_decoupled

    else:
        # Fail with a clear message instead of an unbound `loss_f` below.
        raise ValueError('unsupported loss: %s' % (args.loss,))

    mpnet_pnet = KMPNet(args.total_input_size, args.AE_input_size, args.mlp_input_size, args.output_size // 2,
                   cae, mlp, loss_f)
    mpnet_vnet = KMPNet(args.total_input_size, args.AE_input_size, args.mlp_input_size, args.output_size // 2,
                   cae, mlp, loss_f)

    # Wrap the two nets built above (they are named mpnet_pnet / mpnet_vnet).
    mpnet_pos_vel = PosVelKMPNet(mpnet_pnet, mpnet_vnet)
    # load net
    # load previously trained model if start epoch > 0
    model_dir = args.model_dir
    if args.loss == 'mse':
        if args.multigoal == 0:
            model_dir = model_dir+args.env_type+"_lr%f_%s_step_%d/" % (args.learning_rate, args.opt, args.num_steps)
        else:
            model_dir = model_dir+args.env_type+"_lr%f_%s_step_%d_multigoal/" % (args.learning_rate, args.opt, args.num_steps)
    else:
        if args.multigoal == 0:
            model_dir = model_dir+args.env_type+"_lr%f_%s_loss_%s_step_%d/" % (args.learning_rate, args.opt, args.loss, args.num_steps)
        else:
            model_dir = model_dir+args.env_type+"_lr%f_%s_loss_%s_step_%d_multigoal/" % (args.learning_rate, args.opt, args.loss, args.num_steps)

    if not os.path.exists(model_dir):
        os.makedirs(model_dir)
    model_pnet_path='kmpnet_pnet_epoch_%d_direction_%d_step_%d.pkl' %(args.start_epoch, args.direction, args.num_steps)
    model_vnet_path='kmpnet_vnet_epoch_%d_direction_%d_step_%d.pkl' %(args.start_epoch, args.direction, args.num_steps)
    torch_seed, np_seed, py_seed = 0, 0, 0
    if args.start_epoch > 0:
        load_net_state(mpnet_pnet, os.path.join(model_dir, model_pnet_path))
        load_net_state(mpnet_vnet, os.path.join(model_dir, model_vnet_path))

        # Seeds are taken from the pnet checkpoint only.
        torch_seed, np_seed, py_seed = load_seed(os.path.join(model_dir, model_pnet_path))
        # set seed after loading
        torch.manual_seed(torch_seed)
        np.random.seed(np_seed)
        random.seed(py_seed)

    if torch.cuda.is_available():
        mpnet_pnet.cuda()
        mpnet_pnet.mlp.cuda()
        mpnet_pnet.encoder.cuda()

        mpnet_vnet.cuda()
        mpnet_vnet.mlp.cuda()
        mpnet_vnet.encoder.cuda()

    # load train and test data
    print('loading...')
    if args.debug:
        obs, cost_dataset, cost_targets, env_indices, \
        _, _, _, _ = data_loader.load_train_dataset_cost(N=args.no_env, NP=args.no_motion_paths,
                                                    data_folder=args.path_folder, obs_f=True,
                                                    direction=args.direction,
                                                    dynamics=dynamics, enforce_bounds=enforce_bounds,
                                                    system=system, step_sz=step_sz, num_steps=args.num_steps)
        # randomize the dataset before training
        data=list(zip(cost_dataset,cost_targets,env_indices))
        random.shuffle(data)
        dataset,targets,env_indices=list(zip(*data))
        dataset = list(dataset)
        targets = list(targets)
        env_indices = list(env_indices)
        dataset = np.array(dataset)
        targets = np.array(targets)
        env_indices = np.array(env_indices)
        # record
        bi = dataset.astype(np.float32)
        print('bi shape:')
        print(bi.shape)
        bt = targets
        bi = torch.FloatTensor(bi)
        bt = torch.FloatTensor(bt)
        bi = normalize(bi, args.world_size)
        bi=to_var(bi)
        bt=to_var(bt)
        if obs is None:
            bobs = None
        else:
            bobs = obs[env_indices].astype(np.float32)
            bobs = torch.FloatTensor(bobs)
            bobs = to_var(bobs)
    else:
        # No dataset: a single random sample is enough to drive tracing.
        bobs = np.random.rand(1,1,args.AE_input_size,args.AE_input_size)
        bobs = torch.from_numpy(bobs).type(torch.FloatTensor)
        bobs = to_var(bobs)
        bi = np.random.rand(1, args.total_input_size)
        bt = np.random.rand(1, args.output_size)
        bi = torch.from_numpy(bi).type(torch.FloatTensor)
        bt = torch.from_numpy(bt).type(torch.FloatTensor)
        bi = to_var(bi)
        bt = to_var(bt)
    # NOTE(review): `mpnet` is not assigned anywhere in this function (only
    # mpnet_pnet / mpnet_vnet / mpnet_pos_vel are).  Unless a module-level
    # `mpnet` exists this raises NameError; decide which net should be
    # exported and reference it here.
    # set to training model to enable dropout
    mpnet.train()

    MLP = mpnet.mlp
    encoder = mpnet.encoder
    traced_encoder = torch.jit.trace(encoder, (bobs))
    encoder_output = encoder(bobs)
    # The MLP consumes the encoder output concatenated with the input batch.
    mlp_input = torch.cat((encoder_output, bi), 1)
    traced_MLP = torch.jit.trace(MLP, (mlp_input))
    traced_encoder.save('%s_encoder_lr%f_epoch_%d_step_%d.pt' % (args.env_type, args.learning_rate, args.start_epoch, args.num_steps))
    traced_MLP.save('%s_MLP_lr%f_epoch_%d_step_%d.pt' % (args.env_type, args.learning_rate, args.start_epoch, args.num_steps))

    # Sanity check: run the sample batch through scripted copies of the two
    # modules and print the outputs next to the targets.  Note these are
    # torch.jit.script copies, not the traced modules saved above.
    serilized_encoder = torch.jit.script(encoder)
    serilized_MLP = torch.jit.script(MLP)
    serilized_encoder_output = serilized_encoder(bobs)
    serilized_MLP_input = torch.cat((serilized_encoder_output, bi), 1)
    serilized_MLP_output = serilized_MLP(serilized_MLP_input)
    print('encoder output: ', serilized_encoder_output)
    print('MLP output: ', serilized_MLP_output)
    print('data: ', bt)
# Example #6
# 0
def main(args):
    """Replay stored paths and visualise BVP tracking between waypoints.

    For every path in the "seen" test set the stored controls are
    re-propagated with ``dynamics`` to get a dense state sequence, which is
    then handed to the nested ``plan_one_path``.  That helper plots the
    collision map and the waypoints and, for every ``num_steps``-th
    waypoint, runs the BVP solver several times and plots the propagated
    result that ends closest to the waypoint.

    Args:
        args: Parsed command-line namespace.  Fields read here include
            ``model_path``, ``env_type``, ``device``, ``data_folder`` and
            the ``seen_*`` / ``unseen_*`` dataset selectors.
    """
    # set seed
    print(args.model_path)
    # NOTE(review): the seeds are themselves drawn from the unseeded NumPy
    # RNG, so runs are not reproducible; torch_seed is computed but unused
    # because the torch seeding call below is commented out.
    torch_seed = np.random.randint(low=0, high=1000)
    np_seed = np.random.randint(low=0, high=1000)
    py_seed = np.random.randint(low=0, high=1000)
    #torch.manual_seed(torch_seed)
    np.random.seed(np_seed)
    random.seed(py_seed)
    # Build the models
    #if torch.cuda.is_available():
    #    torch.cuda.set_device(args.device)

    # setup evaluation function and load function
    # NOTE(review): `dynamics`, `step_sz` and `obs_width` are used further
    # down but are not assigned by every branch here (e.g. 'pendulum' and
    # 'cartpole_obs' set none of them) -- confirm which env types are
    # actually supported by this script.
    if args.env_type == 'pendulum':
        obs_file = None
        obc_file = None
        obs_f = False
        #system = standard_cpp_systems.PSOPTPendulum()
        #bvp_solver = _sst_module.PSOPTBVPWrapper(system, 2, 1, 0)
    elif args.env_type == 'cartpole_obs':
        normalize = cartpole.normalize
        unnormalize = cartpole.unnormalize
        obs_file = None
        obc_file = None
        #dynamics = cartpole.dynamics
        #jax_dynamics = cartpole.jax_dynamics
        #enforce_bounds = cartpole.enforce_bounds
        cae = CAE_acrobot_voxel_2d
        mlp = mlp_acrobot.MLP
        obs_f = True
        #system = standard_cpp_systems.RectangleObs(obs_list, args.obs_width, 'cartpole')
        #bvp_solver = _sst_module.PSOPTBVPWrapper(system, 4, 1, 0)
    elif args.env_type == 'acrobot_obs':
        obs_file = None
        obc_file = None
        system = _sst_module.PSOPTAcrobot()
        cpp_propagator = _sst_module.SystemPropagator()
        # Propagate one step through the C++ system.
        dynamics = lambda x, u, t: cpp_propagator.propagate(system, x, u, t)

        obs_f = True
        bvp_solver = _sst_module.PSOPTBVPWrapper(system, 4, 1, 0)
        step_sz = 0.02
        num_steps = 20
        traj_opt = lambda x0, x1, step_sz, num_steps, x_init, u_init, t_init: bvp_solver.solve(x0, x1, 200, num_steps, step_sz*1, step_sz*(num_steps-1), x_init, u_init, t_init)
        obs_width = 6.0
        step_sz = 0.02
        num_steps = 20
        goal_radius=2.0
        random_seed=0
        delta_near=0.1
        delta_drain=0.05


    # NOTE(review): 'acrobot_obs' in this list can never match because the
    # branch above already handles it.  The remaining names do reach this
    # branch, but it assigns neither `dynamics` nor `obs_f`.
    elif args.env_type in ['acrobot_obs','acrobot_obs_2', 'acrobot_obs_3', 'acrobot_obs_4', 'acrobot_obs_8']:
        #system = standard_cpp_systems.RectangleObs(obs[i], 6.0, 'acrobot')
        obs_width = 6.0
        step_sz = 0.02
        num_steps = 20
        goal_radius=2.0
        random_seed=0
        delta_near=0.1
        delta_drain=0.05

    # load previously trained model if start epoch > 0
    #model_path='kmpnet_epoch_%d_direction_0_step_%d.pkl' %(args.start_epoch, args.num_steps)
    # Model files are hard-coded to specific acrobot checkpoints under ./c++/.
    mlp_path = os.path.join(os.getcwd()+'/c++/','acrobot_obs_MLP_lr0.010000_epoch_2850_step_20.pt')
    encoder_path = os.path.join(os.getcwd()+'/c++/','acrobot_obs_encoder_lr0.010000_epoch_2850_step_20.pt')
    cost_mlp_path = os.path.join(os.getcwd()+'/c++/','costnet_acrobot_obs_8_MLP_epoch_300_step_20.pt')
    cost_encoder_path = os.path.join(os.getcwd()+'/c++/','costnet_acrobot_obs_8_encoder_epoch_300_step_20.pt')

    print('mlp_path:')
    print(mlp_path)
    #####################################################
    def plan_one_path(obs_i, obs, obc, start_state, goal_state, goal_inform_state, max_iteration, data, out_queue):
        """Plot the environment and BVP tracking results for one path.

        Args:
            obs_i: Obstacle description passed to ``IsInCollision``.
            obs: Obstacle description passed to ``RectangleObs``.
            obc: Not used in this function.
            start_state: Assigned to ``state_t`` and immediately replaced by
                ``data[0]``.
            goal_state: Not used in this function.
            goal_inform_state: Not used in this function.
            max_iteration: Not used in this function.
            data: Sequence of states; every ``num_steps``-th entry is used
                as a BVP target.
            out_queue: Queue that receives ``0`` when the function finishes.
        """
        # NOTE(review): only the acrobot branch assigns psopt_system,
        # propagate_system, bvp_wrapper, psopt_num_steps and psopt_step_sz,
        # all of which are used unconditionally below.
        if args.env_type == 'pendulum':
            system = standard_cpp_systems.PSOPTPendulum()
            bvp_solver = _sst_module.PSOPTBVPWrapper(system, 2, 1, 0)
            step_sz = 0.002
            num_steps = 20
            traj_opt = lambda x0, x1: bvp_solver.solve(x0, x1, 200, num_steps, 1, 20, step_sz)

        elif args.env_type == 'cartpole_obs':
            #system = standard_cpp_systems.RectangleObs(obs[i], 4.0, 'cartpole')
            system = _sst_module.CartPole()
            bvp_solver = _sst_module.PSOPTBVPWrapper(system, 4, 1, 0)
            step_sz = 0.002
            num_steps = 20
            traj_opt = lambda x0, x1, x_init, u_init, t_init: bvp_solver.solve(x0, x1, 200, num_steps, step_sz*1, step_sz*50, x_init, u_init, t_init)
            goal_S0 = np.identity(4)
            goal_rho0 = 1.0
        elif args.env_type in ['acrobot_obs','acrobot_obs_2', 'acrobot_obs_3', 'acrobot_obs_4', 'acrobot_obs_8']:
            #system = standard_cpp_systems.RectangleObs(obs[i], 6.0, 'acrobot')
            obs_width = 6.0
            psopt_system = _sst_module.PSOPTAcrobot()
            propagate_system = standard_cpp_systems.RectangleObs(obs, 6., 'acrobot')
            distance_computer = propagate_system.distance_computer()
            #distance_computer = _sst_module.euclidean_distance(np.array(propagate_system.is_circular_topology()))
            bvp_wrapper = _sst_module.PSOPTBVPWrapper(psopt_system, 4, 1, 0)
            step_sz = 0.02
            num_steps = 20
            psopt_num_steps = 20
            psopt_step_sz = 0.02
            goal_radius=2
            random_seed=0
            #delta_near=1.0
            #delta_drain=0.5
            delta_near=0.1
            delta_drain=0.05
        #print('creating planner...')
        # NOTE(review): `planner` and `time0` are created but not referenced
        # again in this function.
        planner = vis_planners.DeepSMPWrapper(mlp_path, encoder_path, 
                                              cost_mlp_path, cost_encoder_path, 
                                              20, psopt_num_steps+1, psopt_step_sz, step_sz, propagate_system, args.device)
        # generate a path by using SST to plan for some maximal iterations
        time0 = time.time()
        #print('obc:')
        #print(obc.shape)
        #print(delta_near)
        #print(delta_drain)
        #print('start_state:')
        #print(start_state)
        #print('goal_state:')
        #print(goal_state)

        # Interactive figure over the first two state coordinates, both
        # limited to [-pi, pi].
        plt.ion()
        fig = plt.figure()
        ax = fig.add_subplot(111)
        #ax.set_autoscale_on(True)
        ax.set_xlim(-np.pi, np.pi)
        ax.set_ylim(-np.pi, np.pi)
        hl, = ax.plot([], [], 'b')
        #hl_real, = ax.plot([], [], 'r')
        hl_for, = ax.plot([], [], 'g')
        hl_back, = ax.plot([], [], 'r')
        hl_for_mpnet, = ax.plot([], [], 'lightgreen')
        hl_back_mpnet, = ax.plot([], [], 'salmon')
        
        #print(obs)
        def update_line(h, ax, new_data):
            """Append the angle-wrapped point ``new_data`` to line ``h``."""
            new_data = wrap_angle(new_data, propagate_system)
            h.set_data(np.append(h.get_xdata(), new_data[0]), np.append(h.get_ydata(), new_data[1]))
            #h.set_xdata(np.append(h.get_xdata(), new_data[0]))
            #h.set_ydata(np.append(h.get_ydata(), new_data[1]))

        def remove_last_k(h, ax, k):
            """Drop the last ``k`` points from line ``h``."""
            h.set_data(h.get_xdata()[:-k], h.get_ydata()[:-k])

        def draw_update_line(ax):
            """Redraw the figure and process pending GUI events."""
            #ax.relim()
            #ax.autoscale_view()
            fig.canvas.draw()
            fig.canvas.flush_events()
            #plt.show()

        def wrap_angle(x, system):
            """Return a copy of ``x`` with circular coordinates in (-pi, pi]."""
            circular = system.is_circular_topology()
            res = np.array(x)
            for i in range(len(x)):
                if circular[i]:
                    # use our previously saved version
                    res[i] = x[i] - np.floor(x[i] / (2*np.pi))*(2*np.pi)
                    if res[i] > np.pi:
                        res[i] = res[i] - 2*np.pi
            return res
        # Sample a grid over the first two coordinates (step dtheta, zero for
        # the last two) and classify every grid point by collision status.
        dtheta = 0.1
        feasible_points = []
        infeasible_points = []
        imin = 0
        imax = int(2*np.pi/dtheta)
        circular = psopt_system.is_circular_topology()


        for i in range(imin, imax):
            for j in range(imin, imax):
                x = np.array([dtheta*i-np.pi, dtheta*j-np.pi, 0., 0.])
                # NOTE(review): IsInCollision is not defined in this
                # function -- presumably module-level; confirm.
                if IsInCollision(x, obs_i):
                    infeasible_points.append(x)
                else:
                    feasible_points.append(x)
        feasible_points = np.array(feasible_points)
        infeasible_points = np.array(infeasible_points)
        print('feasible points')
        print(feasible_points)
        print('infeasible points')
        print(infeasible_points)
        ax.scatter(feasible_points[:,0], feasible_points[:,1], c='yellow')
        ax.scatter(infeasible_points[:,0], infeasible_points[:,1], c='pink')
        #for i in range(len(data)):
        #    update_line(hl, ax, data[i])
        
        # Plot all waypoints, with the last one marked as a red star.
        data = np.array(data)
        ax.scatter(data[:,0], data[:,1], c='lightblue', s=10)
        ax.scatter(data[-1,0], data[-1,1], c='red', s=10, marker='*')

        draw_update_line(ax)
        state_t = start_state

        # The start_state assignment above is overridden right away.
        state_t = data[0]
        for data_i in range(0,len(data),num_steps):
            print('iteration: %d' % (data_i))
            print('state_t:')
            print(state_t)    

            
            # Keep the trial whose propagated end state is closest to the
            # current target waypoint.
            min_dis_to_goal = 100000.
            min_xs_to_plot = []
            for trials in range(10):
                # NOTE(review): init_informer is not defined in this
                # function; it is called here with five arguments --
                # presumably a module-level helper; confirm.
                x_init, u_init, t_init = init_informer(propagate_system, state_t, data[data_i], psopt_num_steps+1, psopt_step_sz)
                print('x_init:')
                print(x_init)

                bvp_x, bvp_u, bvp_t = bvp_wrapper.solve(state_t, x_init[-1], 20, psopt_num_steps+1, 0.8*psopt_step_sz*psopt_num_steps, 2*psopt_step_sz*psopt_num_steps, \
                                                        x_init, u_init, t_init)
                print('bvp_x:')
                print(bvp_x)
                print('bvp_u:')
                print(bvp_u)
                print('bvp_t:')
                print(bvp_t)
                if len(bvp_u) != 0:# and bvp_t[0] > 0.01:  # turn bvp_t off if want to use step_bvp
                    # propagate data
                    #p_start = bvp_x[0]
                    p_start = state_t
                    detail_paths = [p_start]
                    detail_controls = []
                    detail_costs = []
                    state = [p_start]
                    control = []
                    cost = []
                    # Roll the solver's controls forward with the real
                    # dynamics, one step_sz at a time.
                    for k in range(len(bvp_t)):
                        #state_i.append(len(detail_paths)-1)
                        max_steps = int(np.round(bvp_t[k]/step_sz))
                        accum_cost = 0.
                        for step in range(1,max_steps+1):
                            # NOTE(review): dynamics / enforce_bounds come
                            # from outside this function; enforce_bounds is
                            # not assigned anywhere in the enclosing main.
                            p_start = dynamics(p_start, bvp_u[k], step_sz)
                            p_start = enforce_bounds(p_start)
                            detail_paths.append(p_start)
                            accum_cost += step_sz
                            # `step % 1 == 0` is always true, so every
                            # propagated state is recorded.
                            if (step % 1 == 0) or (step == max_steps):
                                state.append(p_start)
                                cost.append(accum_cost)
                                accum_cost = 0.

                    xs_to_plot = np.array(state)
                    
                    for i in range(len(xs_to_plot)):
                        xs_to_plot[i] = wrap_angle(xs_to_plot[i], propagate_system)
                    # Distance from the final propagated state to the target,
                    # with circular coordinates wrapped to (-pi, pi].
                    delta_x = xs_to_plot[-1] - data[data_i]
                    for i in range(len(delta_x)):
                        if circular[i]:
                            delta_x[i] = delta_x[i] - np.floor(delta_x[i] / (2*np.pi))*(2*np.pi)
                            if delta_x[i] > np.pi:
                                delta_x[i] = delta_x[i] - 2*np.pi
                    dis = np.linalg.norm(delta_x)
                    if dis <= min_dis_to_goal:
                        min_dis_to_goal = dis
                        min_xs_to_plot = xs_to_plot

            #ax.scatter(xs_to_plot[:,0], xs_to_plot[:,1], c='green')
            # NOTE(review): if none of the trials returned controls,
            # min_xs_to_plot is still the empty list and this indexing fails.
            ax.scatter(min_xs_to_plot[:,0], min_xs_to_plot[:,1], c='green', s=10.0)

            # draw start and goal
            #ax.scatter(start_state[0], goal_state[0], marker='X')
            draw_update_line(ax)
            #state_t = min_xs_to_plot[-1]
            # try using mpnet_res as new start

            # The next segment starts from the stored waypoint, not from the
            # state actually reached by propagation.
            state_t = data[data_i]




            #state_t = min_xs_to_plot[-1]
            print('data_i:')

            print(data[data_i])
            #else:
            #    # in incollision
            #    state_t = data[data_i]
        #if len(res_x) == 0:
        #    print('failed.')
        out_queue.put(0)
        #else:
        #    print('path succeeded.')
        #    out_queue.put(1)
    ####################################################################################



    # load data
    print('loading...')
    if args.seen_N > 0:
        seen_test_data = data_loader.load_test_dataset(args.seen_N, args.seen_NP,
                                  args.data_folder, obs_f, args.seen_s, args.seen_sp)
    if args.unseen_N > 0:
        unseen_test_data = data_loader.load_test_dataset(args.unseen_N, args.unseen_NP,
                                  args.data_folder, obs_f, args.unseen_s, args.unseen_sp)
    # test
    # testing

    queue = Queue(1)
    print('testing...')
    seen_test_suc_rate = 0.
    unseen_test_suc_rate = 0.

    # NOTE(review): seen_test_data is only assigned when args.seen_N > 0 but
    # is unpacked unconditionally here; the unseen set is loaded but unused.
    obc, obs, paths, sgs, path_lengths, controls, costs = seen_test_data
    obc = obc.astype(np.float32)
    #obc = torch.from_numpy(obc)
    #if torch.cuda.is_available():
    #    obc = obc.cuda()
    for i in range(len(paths)):
        # Expand each obstacle entry (its first two values are used as x, y)
        # into the four corners of a square of side obs_width, flattened to
        # eight numbers.
        new_obs_i = []
        obs_i = obs[i]
        for k in range(len(obs_i)):
            obs_pt = []
            obs_pt.append(obs_i[k][0]-obs_width/2)
            obs_pt.append(obs_i[k][1]-obs_width/2)
            obs_pt.append(obs_i[k][0]-obs_width/2)
            obs_pt.append(obs_i[k][1]+obs_width/2)
            obs_pt.append(obs_i[k][0]+obs_width/2)
            obs_pt.append(obs_i[k][1]+obs_width/2)
            obs_pt.append(obs_i[k][0]+obs_width/2)
            obs_pt.append(obs_i[k][1]-obs_width/2)
            new_obs_i.append(obs_pt)
        obs_i = new_obs_i
        #print(obs_i)
        for j in range(len(paths[i])):
            start_state = sgs[i][j][0]
            goal_inform_state = paths[i][j][-1]
            goal_state = sgs[i][j][1]
            #p = Process(target=plan_one_path, args=(obs[i], obc[i], start_state, goal_state, 500, queue))
            
            # propagate data
            p_start = paths[i][j][0]
            detail_paths = [p_start]
            detail_controls = []
            detail_costs = []
            state = [p_start]
            control = []
            cost = []
            # Re-propagate the stored controls: control k is applied for
            # int(costs[i][j][k] / step_sz) steps of length step_sz.
            for k in range(len(controls[i][j])):
                #state_i.append(len(detail_paths)-1)
                max_steps = int(costs[i][j][k]/step_sz)
                accum_cost = 0.
                #print('p_start:')
                #print(p_start)
                #print('data:')
                #print(paths[i][j][k])
                # modify it because of small difference between data and actual propagation
                #p_start = paths[i][j][k]
                #state[-1] = paths[i][j][k]
                for step in range(1,max_steps+1):
                    p_start = dynamics(p_start, controls[i][j][k], step_sz)
                    # NOTE(review): enforce_bounds is not assigned in this
                    # function -- presumably module-level; confirm.
                    p_start = enforce_bounds(p_start)
                    detail_paths.append(p_start)
                    # NOTE(review): this appends the whole control sequence
                    # controls[i][j], whereas `control` below receives the
                    # single control controls[i][j][k] -- verify intent.
                    detail_controls.append(controls[i][j])
                    detail_costs.append(step_sz)
                    accum_cost += step_sz
                    # `step % 1 == 0` is always true, so every step is kept.
                    if (step % 1 == 0) or (step == max_steps):
                        state.append(p_start)
                        #print('control')
                        #print(controls[i][j])
                        control.append(controls[i][j][k])
                        cost.append(accum_cost)
                        accum_cost = 0.
            #print('p_start:')
            #print(p_start)
            #print('data:')
            #print(paths[i][j][-1])
            # Replace the last propagated state with the stored path end.
            state[-1] = paths[i][j][-1]
            data = state

            plan_one_path(obs_i, obs[i], obc[i], start_state, goal_state, goal_inform_state, 1000, data, queue)
def main(args):
    """Visually debug the DeepSMP tree planner on the seen test dataset.

    For every environment and stored path of the seen split, the reference
    path is re-propagated with the system dynamics, and ``plan_one_path`` then
    runs the planner one step at a time while scattering the sampled states
    and the propagated trajectories on a matplotlib figure (state components
    0 and 2, the view used for the cartpole). After every path the running
    success rate and planning-time statistics are printed.

    NOTE(review): only ``env_type == 'cartpole_obs'`` binds every helper used
    below (``dynamics``, ``IsInCollision``, ``enforce_bounds``, ``obs_width``,
    ``obs_f``). The other branches bind exactly the same names as before and
    therefore still stop with an unbound-name error as soon as a missing
    helper is needed; the acrobot equivalents are not visible from here.

    Args:
        args: parsed command-line namespace. Reads ``model_path``,
            ``env_type``, ``learning_rate``, ``start_epoch``, ``num_steps``,
            ``data_folder``, ``seen_N``, ``seen_NP``, ``seen_s``, ``seen_sp``,
            ``unseen_N``, ``unseen_NP``, ``unseen_s`` and ``unseen_sp``.

    Raises:
        ValueError: if ``args.seen_N`` is not positive (the evaluation loop
            only runs on the seen split; previously this surfaced as an
            unbound ``seen_test_data``).
    """
    # set seed
    print(args.model_path)
    # torch_seed is unused while torch seeding is disabled; it is still drawn
    # so the order of random draws is unchanged.
    torch_seed = np.random.randint(low=0, high=1000)
    np_seed = np.random.randint(low=0, high=1000)
    py_seed = np.random.randint(low=0, high=1000)
    np.random.seed(np_seed)
    random.seed(py_seed)

    acrobot_envs = [
        'acrobot_obs', 'acrobot_obs_2', 'acrobot_obs_3', 'acrobot_obs_4',
        'acrobot_obs_8'
    ]

    # Per-environment settings read by the evaluation loop at the bottom of
    # this function. Planner hyper-parameters live in plan_one_path, which
    # defines its own local values.
    if args.env_type == 'pendulum':
        obs_f = False
        step_sz = 0.002
    elif args.env_type == 'cartpole_obs':
        obs_f = True
        obs_width = 4.0
        step_sz = 0.002
        system = _sst_module.PSOPTCartPole()
        cpp_propagator = _sst_module.SystemPropagator()
        dynamics = lambda x, u, t: cpp_propagator.propagate(system, x, u, t)
        IsInCollision = cartpole_IsInCollision
        enforce_bounds = cartpole_enforce_bounds
    elif args.env_type in acrobot_envs:
        obs_width = 6.0
        step_sz = 0.02
        if args.env_type == 'acrobot_obs':
            obs_f = True
            system = _sst_module.PSOPTAcrobot()
            cpp_propagator = _sst_module.SystemPropagator()
            dynamics = lambda x, u, t: cpp_propagator.propagate(
                system, x, u, t)

    # Network files exported for the C++ planner.
    model_dir = os.getcwd() + '/c++/'
    mlp_path = os.path.join(
        model_dir, '%s_MLP_lr%f_epoch_%d_step_%d.pt' %
        (args.env_type, args.learning_rate, args.start_epoch, args.num_steps))
    encoder_path = os.path.join(
        model_dir, '%s_encoder_lr%f_epoch_%d_step_%d.pt' %
        (args.env_type, args.learning_rate, args.start_epoch, args.num_steps))
    # NOTE(review): the cost network files are hard-coded to the acrobot_obs
    # checkpoints regardless of env_type - confirm this is intended.
    cost_mlp_path = os.path.join(
        model_dir, 'costnet_acrobot_obs_MLP_epoch_800_step_10.pt')
    cost_encoder_path = os.path.join(
        model_dir, 'costnet_acrobot_obs_encoder_epoch_800_step_10.pt')

    print('mlp_path:')
    print(mlp_path)

    #####################################################
    def plan_one_path(obs_i, obs, obc, start_state, goal_state,
                      goal_inform_state, cost_i, max_iteration, data,
                      out_queue):
        """Run the planner step by step on one problem and plot every step.

        Args:
            obs_i: per-obstacle corner coordinate lists, passed to
                ``IsInCollision`` when colouring the background grid.
            obs: obstacle description handed to
                ``standard_cpp_systems.RectangleObs``.
            obc: obstacle encoding; flattened before it is given to the
                planner.
            start_state: state the first planner step starts from.
            goal_state: unused here; kept for interface compatibility.
            goal_inform_state: goal state given to the planner and drawn as a
                red star.
            cost_i: cost of the reference path, only printed.
            max_iteration: number of planner steps to run (also forwarded to
                the planner).
            data: unused here; kept for interface compatibility.
            out_queue: receives ``-1`` on failure, otherwise the planning
                time in seconds.
        """
        if args.env_type == 'pendulum':
            # NOTE(review): no propagate_system / psopt_system /
            # distance_computer is configured for the pendulum, so the
            # planner construction below cannot succeed for it.
            step_sz = 0.002
            num_steps = 20
        elif args.env_type == 'cartpole_obs':
            psopt_system = _sst_module.PSOPTCartPole()
            propagate_system = standard_cpp_systems.RectangleObs(
                obs, 4.0, 'cartpole')
            distance_computer = propagate_system.distance_computer()
            step_sz = 0.002
            num_steps = 101
            goal_radius = 1.5
            delta_near = 2.0
            delta_drain = 1.2
        elif args.env_type in acrobot_envs:
            psopt_system = _sst_module.PSOPTAcrobot()
            propagate_system = standard_cpp_systems.RectangleObs(
                obs, 6., 'acrobot')
            distance_computer = propagate_system.distance_computer()
            step_sz = 0.02
            num_steps = 21
            goal_radius = 2.0
            delta_near = 1.0
            delta_drain = 0.5

        planner = vis_planners.DeepSMPWrapper(mlp_path, encoder_path,
                                              cost_mlp_path, cost_encoder_path,
                                              200, num_steps, step_sz,
                                              propagate_system, 3)
        # Effectively disables cost pruning in the planner.
        cost_threshold = 100000000.

        # Figure for the cartpole view: x axis is state component 0, y axis
        # is state component 2.
        plt.ion()
        fig = plt.figure()
        ax = fig.add_subplot(111)
        ax.set_xlim(-30, 30)
        ax.set_ylim(-np.pi, np.pi)

        def draw_update_line(ax):
            # Redraw without rescaling: the axis limits are fixed above.
            fig.canvas.draw()
            fig.canvas.flush_events()

        def wrap_angle(x, system):
            # Map every circular state component into (-pi, pi].
            circular = system.is_circular_topology()
            res = np.array(x)
            for idx in range(len(x)):
                if circular[idx]:
                    res[idx] = x[idx] - np.floor(
                        x[idx] / (2 * np.pi)) * (2 * np.pi)
                    if res[idx] > np.pi:
                        res[idx] = res[idx] - 2 * np.pi
            return res

        # Colour a grid over (component 0, component 2) by collision status;
        # components 1 and 3 are held at zero.
        dx = 1
        dtheta = 0.1
        feasible_points = []
        infeasible_points = []
        for gi in range(int(2 * 30. / dx)):
            for gj in range(int(2 * np.pi / dtheta)):
                x = np.array([dx * gi - 30, 0., dtheta * gj - np.pi, 0.])
                if IsInCollision(x, obs_i):
                    infeasible_points.append(x)
                else:
                    feasible_points.append(x)
        # reshape keeps the arrays two-dimensional even when a list is empty,
        # so the column indexing below cannot raise IndexError.
        feasible_points = np.array(feasible_points).reshape(-1, 4)
        infeasible_points = np.array(infeasible_points).reshape(-1, 4)
        print('feasible points')
        print(feasible_points)
        print('infeasible points')
        print(infeasible_points)
        ax.scatter(feasible_points[:, 0], feasible_points[:, 2], c='yellow')
        ax.scatter(infeasible_points[:, 0], infeasible_points[:, 2], c='pink')
        draw_update_line(ax)
        # visualization end

        state_t = start_state
        pick_goal_threshold = .0
        ax.scatter(goal_inform_state[0],
                   goal_inform_state[2],
                   marker='*',
                   c='red')  # cartpole

        # NOTE(review): the original read undefined names res_u / res_t after
        # the loop (NameError). They now collect the controls and durations
        # of every step that returned a trajectory, so "failed" means that no
        # step produced one - confirm this matches the intended success
        # criterion.
        res_u = []
        res_t = []
        # Time the whole planning loop; previously the timer was restarted on
        # every iteration, so only the last step was measured.
        time0 = time.time()
        for _ in range(max_iteration):
            # flag=1: using MPNet, flag=0: not using MPNet
            goal_prob = random.random()
            flag = 0 if goal_prob <= pick_goal_threshold else 1
            bvp_x, bvp_u, bvp_t, mpnet_res = planner.plan_tree_SMP_step(
                "sst", propagate_system, psopt_system, obc.flatten(), state_t,
                goal_inform_state, goal_inform_state, flag, goal_radius,
                max_iteration, distance_computer, delta_near, delta_drain,
                cost_threshold)

            if len(bvp_u) != 0:
                ax.scatter(mpnet_res[0], mpnet_res[2], c='lightgreen')
                draw_update_line(ax)

                # Propagate the returned controls at step_sz resolution to
                # obtain the dense trajectory that is plotted.
                p_start = bvp_x[0]
                state = [p_start]
                for k in range(len(bvp_u)):
                    max_steps = int(bvp_t[k] / step_sz)
                    for _step in range(max_steps):
                        p_start = dynamics(p_start, bvp_u[k], step_sz)
                        p_start = enforce_bounds(p_start)
                        state.append(p_start)

                xs_to_plot = np.array(state)
                for j in range(len(xs_to_plot)):
                    xs_to_plot[j] = wrap_angle(xs_to_plot[j], propagate_system)
                xs_to_plot = xs_to_plot[::5]
                ax.scatter(xs_to_plot[:, 0], xs_to_plot[:, 2],
                           c='green')  # cartpole

                print('solution: x')
                print(bvp_x)
                print('solution: u')
                print(bvp_u)
                print('solution: t')
                print(bvp_t)
                draw_update_line(ax)

                res_u.extend(bvp_u)
                res_t.extend(bvp_t)

            print('state_t:')
            print(state_t)
            print('mpnet_res')
            print(mpnet_res)
            if flag:
                # In MPNet-informed mode the returned state becomes the next
                # start, whether or not a trajectory came back with it.
                state_t = mpnet_res
                print('after copying to state_t:')
                print('state_t')
                print(state_t)

        plan_time = time.time() - time0

        print('plan time: %fs' % (plan_time))
        if len(res_u) == 0:
            print('failed.')
            out_queue.put(-1)
        else:
            print('path succeeded.')
            print('cost: %f' % (np.sum(res_t)))
            print('cost_threshold: %f' % (cost_threshold))
            print('data cost: %f' % (cost_i))
            out_queue.put(plan_time)

    ####################################################################################

    # load data
    print('loading...')
    if args.seen_N <= 0:
        raise ValueError(
            'seen_N must be positive: the evaluation loop runs on the seen '
            'test data')
    seen_test_data = data_loader.load_test_dataset(args.seen_N, args.seen_NP,
                                                   args.data_folder, obs_f,
                                                   args.seen_s, args.seen_sp)
    if args.unseen_N > 0:
        # Loaded as before, but not evaluated by this script.
        unseen_test_data = data_loader.load_test_dataset(
            args.unseen_N, args.unseen_NP, args.data_folder, obs_f,
            args.unseen_s, args.unseen_sp)

    # testing
    queue = Queue(1)
    print('testing...')

    obc, obs, paths, sgs, path_lengths, controls, costs = seen_test_data
    obc = obc.astype(np.float32)

    plan_times = []
    plan_res_all = []
    for i in range(len(paths)):
        # Expand every obstacle (x, y) into the four corners of a square of
        # side obs_width around it.
        half_width = obs_width / 2
        obs_i = [[
            ob[0] - half_width, ob[1] - half_width,
            ob[0] - half_width, ob[1] + half_width,
            ob[0] + half_width, ob[1] + half_width,
            ob[0] + half_width, ob[1] - half_width
        ] for ob in obs[i]]

        for j in range(len(paths[i])):
            start_state = sgs[i][j][0]
            goal_inform_state = paths[i][j][-1]
            goal_state = sgs[i][j][1]
            cost_i = costs[i][j].sum()

            # Re-propagate the reference path at step_sz resolution. Every
            # segment restarts from its stored waypoint because of the small
            # difference between the data and the actual propagation.
            state = [paths[i][j][0]]
            for k in range(len(controls[i][j])):
                max_steps = int(costs[i][j][k] / step_sz)
                p_start = paths[i][j][k]
                state[-1] = paths[i][j][k]
                for _step in range(max_steps):
                    p_start = dynamics(p_start, controls[i][j][k], step_sz)
                    p_start = enforce_bounds(p_start)
                    state.append(p_start)
            state[-1] = paths[i][j][-1]
            data = state
            # end of propagation

            print('environment: %d/%d, path: %d/%d' %
                  (i + 1, len(paths), j + 1, len(paths[i])))
            plan_one_path(obs_i, obs[i], obc[i], start_state, goal_state,
                          goal_inform_state, cost_i, 300000, data, queue)
            res = queue.get()
            if res == -1:
                plan_res_all.append(0)
            else:
                plan_times.append(res)
                plan_res_all.append(1)
            print('average accuracy up to now: %f' % (np.mean(plan_res_all)))
            print('plan average time: %f' % (np.array(plan_times).mean()))
            print('plan time std: %f' % (np.array(plan_times).std()))
    # Use the flat result list so environments with different numbers of
    # paths cannot produce a ragged array.
    print('plan accuracy: %f' % (np.array(plan_res_all).mean()))
    print('plan average time: %f' % (np.array(plan_times).mean()))
    print('plan time std: %f' % (np.array(plan_times).std()))
Example #8
0
def main(args):
    """Show the summed output of the cost network as a heat map per test path.

    For every environment and stored path of the seen test set, the first two
    state components are swept over a grid from -pi in steps of 0.1 (the last
    two are held at 0). For each grid state ``x`` the network is queried with
    (path start, x) and with (x, path end); the two outputs are added,
    shifted so the minimum becomes 1, log-scaled and displayed with
    ``plt.imshow``. The grid states are also split with ``IsInCollision`` and
    printed.

    NOTE(review): ``obs_width`` and ``IsInCollision`` are configured for the
    acrobot environments only. The pendulum / cartpole branches are unchanged
    and still fail with an unbound name once those are needed.

    Args:
        args: parsed command-line namespace. Reads ``device``, ``env_type``,
            ``total_input_size``, ``AE_input_size``, ``mlp_input_size``,
            ``output_size``, ``model_dir``, ``learning_rate``, ``opt``,
            ``num_steps``, ``start_epoch``, ``direction``, ``seen_N``,
            ``seen_NP``, ``path_folder``, ``seen_s``, ``seen_sp`` and
            ``world_size``.
    """
    #global hl
    if torch.cuda.is_available():
        torch.cuda.set_device(args.device)
    # environment setting
    cae = cae_identity
    mlp = MLP
    cpp_propagator = _sst_module.SystemPropagator()

    # (MLP class, encoder class) for every acrobot variant. The pairs are
    # wrapped in lambdas so only the selected variant's names are looked up.
    acrobot_networks = {
        'acrobot_obs': lambda: (mlp_acrobot.MLP, CAE_acrobot_voxel_2d),
        'acrobot_obs_2': lambda: (mlp_acrobot.MLP2, CAE_acrobot_voxel_2d_2),
        'acrobot_obs_3': lambda: (mlp_acrobot.MLP3, CAE_acrobot_voxel_2d_2),
        'acrobot_obs_4': lambda: (mlp_acrobot.MLP3, CAE_acrobot_voxel_2d_3),
        'acrobot_obs_5': lambda: (mlp_acrobot.MLP, CAE_acrobot_voxel_2d_3),
        'acrobot_obs_6': lambda: (mlp_acrobot.MLP4, CAE_acrobot_voxel_2d_3),
        'acrobot_obs_7': lambda: (mlp_acrobot.MLP5, CAE_acrobot_voxel_2d_3),
        'acrobot_obs_8': lambda: (mlp_acrobot.MLP6, CAE_acrobot_voxel_2d_3),
    }

    # Only normalize, mlp, cae, obs_width and IsInCollision are read further
    # down; the remaining per-environment values are not used by this
    # visualization.
    if args.env_type == 'pendulum':
        normalize = pendulum.normalize
        unnormalize = pendulum.unnormalize
        system = standard_cpp_systems.PSOPTPendulum()
        dynamics = None
        enforce_bounds = None
        step_sz = 0.002
        num_steps = 20
    elif args.env_type == 'cartpole':
        normalize = cart_pole.normalize
        unnormalize = cart_pole.unnormalize
        dynamics = cartpole.dynamics
        system = _sst_module.CartPole()
        enforce_bounds = cartpole.enforce_bounds
        step_sz = 0.002
        num_steps = 20
    elif args.env_type == 'cartpole_obs':
        normalize = cart_pole_obs.normalize
        unnormalize = cart_pole_obs.unnormalize
        system = _sst_module.CartPole()
        dynamics = cartpole.dynamics
        enforce_bounds = cartpole.enforce_bounds
        step_sz = 0.002
        num_steps = 20
    elif args.env_type in acrobot_networks:
        normalize = acrobot_obs.normalize
        unnormalize = acrobot_obs.unnormalize
        system = _sst_module.PSOPTAcrobot()
        mlp, cae = acrobot_networks[args.env_type]()
        dynamics = lambda x, u, t: cpp_propagator.propagate(system, x, u, t)
        enforce_bounds = acrobot_obs.enforce_bounds
        step_sz = 0.02
        num_steps = 20
        # Previously bound for 'acrobot_obs' only, although the loops below
        # need both for every environment; all variants share the acrobot
        # system, so they now share these as well.
        obs_width = 6.0
        IsInCollision = acrobot_obs.IsInCollision

    mpnet = KMPNet(args.total_input_size, args.AE_input_size,
                   args.mlp_input_size, args.output_size, cae, mlp)
    # load net
    # load previously trained model if start epoch > 0
    model_dir = args.model_dir
    model_dir = model_dir + 'cost_' + args.env_type + "_lr%f_%s_step_%d/" % (
        args.learning_rate, args.opt, args.num_steps)
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)
    model_path = 'cost_kmpnet_epoch_%d_direction_%d_step_%d.pkl' % (
        args.start_epoch, args.direction, args.num_steps)
    torch_seed, np_seed, py_seed = 0, 0, 0
    if args.start_epoch > 0:
        load_net_state(mpnet, os.path.join(model_dir, model_path))
        torch_seed, np_seed, py_seed = load_seed(
            os.path.join(model_dir, model_path))
        # set seed after loading
        torch.manual_seed(torch_seed)
        np.random.seed(np_seed)
        random.seed(py_seed)
    # Moving the network to the GPU and creating its optimizer is disabled in
    # this script; the network is evaluated on the CPU below.
    # NOTE(review): load_opt_state still runs although no optimizer is set
    # here - confirm it tolerates that.
    if args.start_epoch > 0:
        load_opt_state(mpnet, os.path.join(model_dir, model_path))

    # load train and test data
    print('loading...')
    seen_test_data = data_loader.load_test_dataset(args.seen_N, args.seen_NP,
                                                   args.path_folder, True,
                                                   args.seen_s, args.seen_sp)
    obc, obs, paths, sgs, path_lengths, controls, costs = seen_test_data
    obc = obc.astype(np.float32)

    for pi in range(len(paths)):
        # Expand every obstacle (x, y) into the four corners of a square of
        # side obs_width around it, the layout handed to IsInCollision.
        half_width = obs_width / 2
        obs_i = [[
            ob[0] - half_width, ob[1] - half_width,
            ob[0] - half_width, ob[1] + half_width,
            ob[0] + half_width, ob[1] + half_width,
            ob[0] + half_width, ob[1] - half_width
        ] for ob in obs[pi]]

        for pj in range(len(paths[pi])):
            # on the entire state space, visualize the cost
            dtheta = 0.1
            feasible_points = []
            infeasible_points = []

            imin = 0
            imax = int(2 * np.pi / dtheta)
            grid_n = imax - imin

            x0 = paths[pi][pj][0]
            xT = paths[pi][pj][-1]
            # Build one normalized (start, x) and one (x, goal) network input
            # for every grid state.
            cost_to_come = []
            cost_to_go = []
            for i in range(imin, imax):
                for j in range(imin, imax):
                    x = np.array(
                        [dtheta * i - np.pi, dtheta * j - np.pi, 0., 0.])
                    cost_to_come_in = np.array([np.concatenate([x0, x])])
                    cost_to_come_in = torch.from_numpy(cost_to_come_in).type(
                        torch.FloatTensor)
                    cost_to_come_in = normalize(cost_to_come_in,
                                                args.world_size)
                    cost_to_go_in = np.array([np.concatenate([x, xT])])
                    cost_to_go_in = torch.from_numpy(cost_to_go_in).type(
                        torch.FloatTensor)
                    cost_to_go_in = normalize(cost_to_go_in, args.world_size)

                    cost_to_come.append(cost_to_come_in)
                    cost_to_go.append(cost_to_go_in)
            cost_to_come = torch.cat(cost_to_come, 0)
            cost_to_go = torch.cat(cost_to_go, 0)
            print(cost_to_go.size())
            obc_i_torch = torch.from_numpy(np.array([obc[pi]])).type(
                torch.FloatTensor).repeat(len(cost_to_go), 1, 1, 1)
            print(obc_i_torch.size())

            # Evaluate the network once per input batch and reuse the result
            # for both the printed values and the summed map (it used to be
            # evaluated twice on each batch).
            come_out = mpnet(cost_to_come, obc_i_torch).detach()
            go_out = mpnet(cost_to_go, obc_i_torch).detach()
            cost_to_come_val = come_out.numpy().reshape(grid_n, -1)
            cost_to_go_val = go_out.numpy().reshape(grid_n, -1)
            print('cost_to_come:')
            print(cost_to_come_val)
            print('cost_to_come[(imax+imin)//2,(imax+imin)//2]: ',
                  cost_to_come_val[(imax + imin) // 2, (imax + imin) // 2])
            print('cost_to_go_val:')
            print(cost_to_go_val)
            cost_sum = (come_out + go_out)[:, 0].numpy().reshape(grid_n, -1)

            # Same values as the former element-by-element copy.
            costmaps = np.array(cost_sum[imin:imax, imin:imax])
            # plot the costmap
            print(costmaps)
            print(costmaps.min())
            print(costmaps.max())
            costmaps = costmaps - costmaps.min() + 1.0  # map to 1.0 to infty
            costmaps = np.log(costmaps)
            im = plt.imshow(costmaps, cmap='hot', interpolation='nearest')

            # Split the same grid by collision status; the result is only
            # printed, not drawn.
            for i in range(imin, imax):
                for j in range(imin, imax):
                    x = np.array(
                        [dtheta * i - np.pi, dtheta * j - np.pi, 0., 0.])
                    if IsInCollision(x, obs_i):
                        infeasible_points.append(x)
                    else:
                        feasible_points.append(x)
            feasible_points = np.array(feasible_points)
            infeasible_points = np.array(infeasible_points)
            print('feasible points')
            print(feasible_points)
            print('infeasible points')
            print(infeasible_points)

            plt.colorbar(im)
            plt.show()
            plt.waitforbuttonpress()
Example #9
0
def main(args):
    """Export the encoder and MLP of a cost network as traced TorchScript.

    Builds a ``KMPNet`` for ``args.env_type``, restores the checkpoint of
    ``args.start_epoch`` when that is positive, traces the encoder and the
    MLP on one example input and saves both traced modules into the current
    working directory. With ``args.debug`` the example input is the shuffled
    training set loaded through ``data_loader``; otherwise it is a single
    random sample sized from ``args.AE_input_size``,
    ``args.total_input_size`` and ``args.output_size``.

    Args:
        args: parsed command-line namespace. Reads ``device``, ``debug``,
            ``env_type``, ``total_input_size``, ``AE_input_size``,
            ``mlp_input_size``, ``output_size``, ``model_dir``,
            ``learning_rate``, ``opt``, ``num_steps``, ``start_epoch``,
            ``direction`` and, in debug mode, ``no_env``,
            ``no_motion_paths``, ``path_folder`` and ``world_size``.

    Raises:
        ValueError: if no network architecture is defined for
            ``args.env_type`` (only 'acrobot_obs' and 'acrobot_obs_8' are).
    """
    if torch.cuda.is_available():
        torch.cuda.set_device(args.device)

    # Only these environments define the (mlp, cae) pair KMPNet needs. Any
    # other value cannot be exported, so reject it with a readable message
    # instead of failing later on an unbound local.
    if args.env_type == 'acrobot_obs':
        mlp = mlp_acrobot.MLP
        cae = CAE_acrobot_voxel_2d
    elif args.env_type == 'acrobot_obs_8':
        mlp = mlp_acrobot.MLP6
        cae = CAE_acrobot_voxel_2d_3
    else:
        raise ValueError(
            "no cost network architecture defined for env_type %r; "
            "expected 'acrobot_obs' or 'acrobot_obs_8'" % (args.env_type, ))

    if args.debug:
        # Project modules that are only needed to load real training data.
        from sparse_rrt import _sst_module
        from plan_utility import acrobot_obs
        from tools import data_loader

        # Both supported environments share the same acrobot settings.
        cpp_propagator = _sst_module.SystemPropagator()
        normalize = acrobot_obs.normalize
        system = _sst_module.PSOPTAcrobot()
        dynamics = lambda x, u, t: cpp_propagator.propagate(system, x, u, t)
        enforce_bounds = acrobot_obs.enforce_bounds
        step_sz = 0.02

    mpnet = KMPNet(args.total_input_size, args.AE_input_size,
                   args.mlp_input_size, args.output_size, cae, mlp)

    # Checkpoint directory and file name are derived from the run settings.
    model_dir = args.model_dir + 'cost_' + args.env_type + \
        "_lr%f_%s_step_%d/" % (args.learning_rate, args.opt, args.num_steps)
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(model_dir, exist_ok=True)
    model_path = 'cost_kmpnet_epoch_%d_direction_%d_step_%d.pkl' % (
        args.start_epoch, args.direction, args.num_steps)
    checkpoint = os.path.join(model_dir, model_path)

    if args.start_epoch > 0:
        load_net_state(mpnet, checkpoint)
        torch_seed, np_seed, py_seed = load_seed(checkpoint)
        # set seed after loading
        torch.manual_seed(torch_seed)
        np.random.seed(np_seed)
        random.seed(py_seed)

    # NOTE(review): the optimizer is only configured when CUDA is available,
    # yet load_opt_state() below runs regardless - confirm that restoring a
    # checkpoint on a CPU-only machine works.
    if torch.cuda.is_available():
        mpnet.cuda()
        mpnet.mlp.cuda()
        mpnet.encoder.cuda()
        if args.opt == 'Adagrad':
            mpnet.set_opt(torch.optim.Adagrad, lr=args.learning_rate)
        elif args.opt == 'Adam':
            mpnet.set_opt(torch.optim.Adam, lr=args.learning_rate)
        elif args.opt == 'SGD':
            mpnet.set_opt(torch.optim.SGD, lr=args.learning_rate, momentum=0.9)
        elif args.opt == 'ASGD':
            mpnet.set_opt(torch.optim.ASGD, lr=args.learning_rate)
    if args.start_epoch > 0:
        load_opt_state(mpnet, checkpoint)

    # Build the example inputs used for tracing.
    print('loading...')
    if args.debug:
        obs, cost_dataset, cost_targets, env_indices, _, _, _, _ = \
            data_loader.load_train_dataset_cost(
                N=args.no_env, NP=args.no_motion_paths,
                data_folder=args.path_folder, obs_f=True,
                direction=args.direction,
                dynamics=dynamics, enforce_bounds=enforce_bounds,
                system=system, step_sz=step_sz, num_steps=args.num_steps)
        # Shuffle inputs, targets and environment indices together.
        samples = list(zip(cost_dataset, cost_targets, env_indices))
        random.shuffle(samples)
        dataset, targets, env_indices = (
            np.array(column) for column in zip(*samples))

        bi = dataset.astype(np.float32)
        print('bi shape:')
        print(bi.shape)
        bi = torch.FloatTensor(bi)
        bt = torch.FloatTensor(targets)
        bi = normalize(bi, args.world_size)
        bi = to_var(bi)
        bt = to_var(bt)
        if obs is None:
            bobs = None
        else:
            bobs = torch.FloatTensor(obs[env_indices].astype(np.float32))
            bobs = to_var(bobs)
    else:
        # One random sample is enough to record the traced graph.
        bobs = np.random.rand(1, 1, args.AE_input_size, args.AE_input_size)
        bobs = torch.from_numpy(bobs).type(torch.FloatTensor)
        bobs = to_var(bobs)
        bi = np.random.rand(1, args.total_input_size)
        bt = np.random.rand(1, args.output_size)
        bi = torch.from_numpy(bi).type(torch.FloatTensor)
        bt = torch.from_numpy(bt).type(torch.FloatTensor)
        bi = to_var(bi)
        bt = to_var(bt)

    # Evaluation mode for the export (the train() call was disabled here).
    mpnet.eval()

    mlp_net = mpnet.mlp
    encoder = mpnet.encoder
    traced_encoder = torch.jit.trace(encoder, bobs)
    # The MLP is traced on the encoder output concatenated with the input.
    encoder_output = encoder(bobs)
    mlp_input = torch.cat((encoder_output, bi), 1)
    traced_MLP = torch.jit.trace(mlp_net, mlp_input)
    traced_encoder.save("costnet_%s_encoder_epoch_%d_step_%d.pt" %
                        (args.env_type, args.start_epoch, args.num_steps))
    traced_MLP.save("costnet_%s_MLP_epoch_%d_step_%d.pt" %
                    (args.env_type, args.start_epoch, args.num_steps))

    # Sanity check: run scripted versions of both modules on the same input
    # and print their outputs next to the target.
    # NOTE(review): this scripts the modules afresh instead of running the
    # traced ones saved above - confirm which one is meant to be checked.
    serilized_encoder = torch.jit.script(encoder)
    serilized_MLP = torch.jit.script(mlp_net)
    serilized_encoder_output = serilized_encoder(bobs)
    serilized_MLP_input = torch.cat((serilized_encoder_output, bi), 1)
    serilized_MLP_output = serilized_MLP(serilized_MLP_input)
    print('encoder output: ', serilized_encoder_output)
    print('MLP output: ', serilized_MLP_output)
    print('data: ', bt)
def main(args):
    """Print cost-network predictions next to their targets, batch by batch.

    Builds a ``KMPNet`` for ``args.env_type``, restores the checkpoint of
    ``args.start_epoch`` when that is positive, loads and shuffles the cost
    training set, then runs the network in evaluation mode over consecutive
    batches of ``args.batch_size`` samples, printing the network output and
    the matching target for each batch.

    Args:
        args: parsed command-line namespace. Reads ``device``, ``env_type``,
            ``total_input_size``, ``AE_input_size``, ``mlp_input_size``,
            ``output_size``, ``model_dir``, ``learning_rate``, ``opt``,
            ``num_steps``, ``start_epoch``, ``direction``, ``no_env``,
            ``no_motion_paths``, ``path_folder``, ``batch_size`` and
            ``world_size``.

    Raises:
        ValueError: if ``args.env_type`` is not one of the environments
            configured below.
    """
    if torch.cuda.is_available():
        torch.cuda.set_device(args.device)

    # (mlp, cae) pair of every acrobot variant. The pairs are wrapped in
    # thunks so only the selected variant's classes are actually looked up.
    acrobot_networks = {
        'acrobot_obs': lambda: (mlp_acrobot.MLP, CAE_acrobot_voxel_2d),
        'acrobot_obs_2': lambda: (mlp_acrobot.MLP2, CAE_acrobot_voxel_2d_2),
        'acrobot_obs_3': lambda: (mlp_acrobot.MLP3, CAE_acrobot_voxel_2d_2),
        'acrobot_obs_4': lambda: (mlp_acrobot.MLP3, CAE_acrobot_voxel_2d_3),
        'acrobot_obs_5': lambda: (mlp_acrobot.MLP, CAE_acrobot_voxel_2d_3),
        'acrobot_obs_6': lambda: (mlp_acrobot.MLP4, CAE_acrobot_voxel_2d_3),
        'acrobot_obs_7': lambda: (mlp_acrobot.MLP5, CAE_acrobot_voxel_2d_3),
        'acrobot_obs_8': lambda: (mlp_acrobot.MLP6, CAE_acrobot_voxel_2d_3),
    }

    # environment setting; non-acrobot systems keep the default networks
    cae = cae_identity
    mlp = MLP
    cpp_propagator = _sst_module.SystemPropagator()
    if args.env_type == 'pendulum':
        normalize = pendulum.normalize
        system = standard_cpp_systems.PSOPTPendulum()
        dynamics = None
        enforce_bounds = None
        step_sz = 0.002
    elif args.env_type == 'cartpole':
        # NOTE(review): normalize comes from `cart_pole` while dynamics and
        # bounds come from `cartpole` - confirm both names are in scope.
        normalize = cart_pole.normalize
        dynamics = cartpole.dynamics
        system = _sst_module.CartPole()
        enforce_bounds = cartpole.enforce_bounds
        step_sz = 0.002
    elif args.env_type == 'cartpole_obs':
        normalize = cart_pole_obs.normalize
        system = _sst_module.CartPole()
        dynamics = cartpole.dynamics
        enforce_bounds = cartpole.enforce_bounds
        step_sz = 0.002
    elif args.env_type in acrobot_networks:
        # Every acrobot variant shares these settings; only the network
        # architecture differs.
        mlp, cae = acrobot_networks[args.env_type]()
        normalize = acrobot_obs.normalize
        system = _sst_module.PSOPTAcrobot()
        dynamics = lambda x, u, t: cpp_propagator.propagate(system, x, u, t)
        enforce_bounds = acrobot_obs.enforce_bounds
        step_sz = 0.02
    else:
        # Fail here rather than on an unbound local once data loading starts.
        raise ValueError('unsupported env_type: %r' % (args.env_type, ))

    mpnet = KMPNet(args.total_input_size, args.AE_input_size,
                   args.mlp_input_size, args.output_size, cae, mlp)

    # Checkpoint directory and file name are derived from the run settings.
    model_dir = args.model_dir + 'cost_' + args.env_type + \
        "_lr%f_%s_step_%d/" % (args.learning_rate, args.opt, args.num_steps)
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(model_dir, exist_ok=True)
    model_path = 'cost_kmpnet_epoch_%d_direction_%d_step_%d.pkl' % (
        args.start_epoch, args.direction, args.num_steps)
    checkpoint = os.path.join(model_dir, model_path)

    if args.start_epoch > 0:
        load_net_state(mpnet, checkpoint)
        torch_seed, np_seed, py_seed = load_seed(checkpoint)
        # set seed after loading
        torch.manual_seed(torch_seed)
        np.random.seed(np_seed)
        random.seed(py_seed)

    # NOTE(review): the optimizer is only configured when CUDA is available,
    # yet load_opt_state() below runs regardless - confirm that restoring a
    # checkpoint on a CPU-only machine works.
    if torch.cuda.is_available():
        mpnet.cuda()
        mpnet.mlp.cuda()
        mpnet.encoder.cuda()
        if args.opt == 'Adagrad':
            mpnet.set_opt(torch.optim.Adagrad, lr=args.learning_rate)
        elif args.opt == 'Adam':
            mpnet.set_opt(torch.optim.Adam, lr=args.learning_rate)
        elif args.opt == 'SGD':
            mpnet.set_opt(torch.optim.SGD, lr=args.learning_rate, momentum=0.9)
        elif args.opt == 'ASGD':
            mpnet.set_opt(torch.optim.ASGD, lr=args.learning_rate)
    if args.start_epoch > 0:
        load_opt_state(mpnet, checkpoint)
    mpnet.eval()

    # load the cost dataset
    print('loading...')
    obs, cost_dataset, cost_targets, env_indices, _, _, _, _ = \
        data_loader.load_train_dataset_cost(
            N=args.no_env, NP=args.no_motion_paths,
            data_folder=args.path_folder, obs_f=True,
            direction=args.direction,
            dynamics=dynamics, enforce_bounds=enforce_bounds,
            system=system, step_sz=step_sz, num_steps=args.num_steps)
    # Shuffle inputs, targets and environment indices together.
    samples = list(zip(cost_dataset, cost_targets, env_indices))
    random.shuffle(samples)
    dataset, targets, env_indices = (
        np.array(column) for column in zip(*samples))

    for begin in range(0, len(dataset), args.batch_size):
        end = begin + args.batch_size
        bi = dataset[begin:end].astype(np.float32)
        print('bi shape:')
        print(bi.shape)
        bi = torch.FloatTensor(bi)
        bt = torch.FloatTensor(targets[begin:end])
        bi = normalize(bi, args.world_size)
        bi = to_var(bi)
        bt = to_var(bt)
        if obs is None:
            bobs = None
        else:
            # One obstacle representation per sample, picked by env index.
            bobs = obs[env_indices[begin:end]].astype(np.float32)
            bobs = torch.FloatTensor(bobs)
            bobs = to_var(bobs)
        print('cost network output: ')
        print(mpnet(bi, bobs).cpu().data)
        print('target: ')
        print(bt.cpu().data)
def main(args):
    """Time the DeepSMP planner on every query of the seen test set.

    Loads the seen (and, if requested, unseen) test data, then for each
    (environment, path) pair of the seen set re-propagates the stored
    controls through the dynamics and hands the query to a child process.
    The child runs ``vis_planners.DeepSMPWrapper`` for 1000 planning steps
    and reports the elapsed planning time, or -1 on failure, through a queue.

    Only the acrobot environments carry the dynamics, obstacle width and
    planner settings this driver needs, so any other ``args.env_type`` is
    rejected.

    Args:
        args: parsed command-line namespace. Reads ``model_path``,
            ``env_type``, ``device``, ``data_folder``, ``seen_N``,
            ``seen_NP``, ``seen_s``, ``seen_sp``, ``unseen_N``,
            ``unseen_NP``, ``unseen_s`` and ``unseen_sp``.

    Raises:
        ValueError: if ``args.env_type`` is not a supported acrobot variant.
    """
    # set seed
    print(args.model_path)
    # Three values are drawn so the NumPy generator is consumed as before;
    # the first one is the torch seed, which is not applied (torch seeding is
    # disabled in this driver).
    _torch_seed = np.random.randint(low=0, high=1000)
    np_seed = np.random.randint(low=0, high=1000)
    py_seed = np.random.randint(low=0, high=1000)
    np.random.seed(np_seed)
    random.seed(py_seed)

    # setup evaluation function and load function
    acrobot_envs = ('acrobot_obs', 'acrobot_obs_2', 'acrobot_obs_3',
                    'acrobot_obs_4', 'acrobot_obs_8')
    if args.env_type not in acrobot_envs:
        raise ValueError(
            'unsupported env_type %r; this driver is only configured for %s'
            % (args.env_type, ', '.join(acrobot_envs)))

    # One shared configuration for all acrobot variants, so every variant
    # gets the dynamics and obstacle flag that the loops below rely on.
    system = _sst_module.PSOPTAcrobot()
    cpp_propagator = _sst_module.SystemPropagator()
    dynamics = lambda x, u, t: cpp_propagator.propagate(system, x, u, t)
    obs_f = True
    obs_width = 6.0
    step_sz = 0.02

    # Serialized planner and cost networks, looked up under ./c++/.
    mlp_path = os.path.join(
        os.getcwd() + '/c++/',
        'acrobot_obs_MLP_lr0.010000_epoch_2850_step_20.pt')
    encoder_path = os.path.join(
        os.getcwd() + '/c++/',
        'acrobot_obs_encoder_lr0.010000_epoch_2850_step_20.pt')
    cost_mlp_path = os.path.join(
        os.getcwd() + '/c++/',
        'costnet_acrobot_obs_8_MLP_epoch_300_step_20.pt')
    cost_encoder_path = os.path.join(
        os.getcwd() + '/c++/',
        'costnet_acrobot_obs_8_encoder_epoch_300_step_20.pt')

    print('mlp_path:')
    print(mlp_path)

    #####################################################
    def plan_one_path(obs_i, obs, obc, start_state, goal_state,
                      goal_inform_state, max_iteration, data, out_queue):
        """Plan one query and put the planning time (or -1) on ``out_queue``.

        ``obs`` builds the rectangle-obstacle system, ``obc`` is flattened
        and passed to every planning step, and ``goal_inform_state`` is
        passed as both goal arguments of the step. ``obs_i``, ``goal_state``
        and ``data`` are accepted but not used here.
        """
        psopt_system = _sst_module.PSOPTAcrobot()
        propagate_system = standard_cpp_systems.RectangleObs(
            obs, 6., 'acrobot')
        distance_computer = propagate_system.distance_computer()

        step_sz = 0.02
        num_steps = 21
        goal_radius = 2
        delta_near = 0.1
        delta_drain = 0.05

        planner = vis_planners.DeepSMPWrapper(mlp_path, encoder_path,
                                              cost_mlp_path, cost_encoder_path,
                                              20, num_steps, step_sz,
                                              propagate_system, args.device)

        state_t = start_state

        # The goal-picking threshold starts at 0.10 and, after 60% of the
        # iterations, grows linearly towards 0.95 at the last iteration.
        pick_goal_threshold = 0.10
        goal_linear_inc_start_iter = int(0.6 * max_iteration)
        goal_linear_inc_end_iter = max_iteration
        goal_linear_inc_end_threshold = 0.95
        goal_linear_inc = (goal_linear_inc_end_threshold -
                           pick_goal_threshold) / (goal_linear_inc_end_iter -
                                                   goal_linear_inc_start_iter)

        start_time = time.time()
        plan_time = -1
        for i in tqdm(range(max_iteration)):
            print('iteration: %d' % (i))
            print('state_t:')
            print(state_t)
            # calculate if using goal or not
            use_goal_prob = random.random()
            if i > goal_linear_inc_start_iter:
                pick_goal_threshold += goal_linear_inc
            flag = 0 if use_goal_prob <= pick_goal_threshold else 1

            planner.plan_step("sst", propagate_system, psopt_system,
                              obc.flatten(), state_t, goal_inform_state,
                              goal_inform_state, flag, goal_radius,
                              max_iteration, distance_computer, delta_near,
                              delta_drain)

            # NOTE(review): plan_time is overwritten on every iteration that
            # has a solution, so it is the time of the last such iteration,
            # not of the first solution - confirm this is intended.
            solution = planner.get_solution()
            if solution is not None:
                plan_time = time.time() - start_time

        if plan_time == -1:
            print('failed.')
            out_queue.put(-1)
        else:
            print('path succeeded.')
            out_queue.put(plan_time)

    ####################################################################################

    # load data
    print('loading...')
    seen_test_data = None
    if args.seen_N > 0:
        seen_test_data = data_loader.load_test_dataset(args.seen_N,
                                                       args.seen_NP,
                                                       args.data_folder, obs_f,
                                                       args.seen_s,
                                                       args.seen_sp)
    if args.unseen_N > 0:
        # Loaded as before; only the seen set is evaluated below.
        unseen_test_data = data_loader.load_test_dataset(
            args.unseen_N, args.unseen_NP, args.data_folder, obs_f,
            args.unseen_s, args.unseen_sp)

    queue = Queue(1)
    print('testing...')
    if seen_test_data is None:
        # Nothing was loaded for the seen set, so there is nothing to plan.
        print('no seen test data requested (seen_N <= 0); nothing to test.')
        return

    obc, obs, paths, sgs, path_lengths, controls, costs = seen_test_data
    obc = obc.astype(np.float32)
    half_width = obs_width / 2
    for i in range(len(paths)):
        # Expand every obstacle centre into the four corners of a square of
        # side obs_width, flattened as [x0, y0, x1, y1, x2, y2, x3, y3].
        obs_i = []
        for center in obs[i]:
            cx, cy = center[0], center[1]
            obs_i.append([
                cx - half_width, cy - half_width,
                cx - half_width, cy + half_width,
                cx + half_width, cy + half_width,
                cx + half_width, cy - half_width,
            ])

        for j in range(len(paths[i])):
            start_state = sgs[i][j][0]
            goal_inform_state = paths[i][j][-1]
            goal_state = sgs[i][j][1]

            # Re-propagate the stored controls into a dense state sequence.
            p_start = paths[i][j][0]
            state = [p_start]
            for k in range(len(controls[i][j])):
                max_steps = int(costs[i][j][k] / step_sz)
                # Restart each segment from the stored waypoint because of
                # small differences between data and actual propagation.
                p_start = paths[i][j][k]
                state[-1] = paths[i][j][k]
                for _ in range(max_steps):
                    p_start = dynamics(p_start, controls[i][j][k], step_sz)
                    # NOTE(review): enforce_bounds is not assigned anywhere
                    # in this function - confirm it exists at module level.
                    p_start = enforce_bounds(p_start)
                    state.append(p_start)
            state[-1] = paths[i][j][-1]
            data = state

            p = Process(target=plan_one_path,
                        args=(obs_i, obs[i], obc[i], start_state, goal_state,
                              goal_inform_state, 1000, data, queue))
            p.start()
            # Drain the queue before joining, as the multiprocessing
            # guidelines require for processes that put items on a queue.
            # NOTE(review): get() blocks indefinitely if the child exits
            # without reporting a result.
            queue.get()
            p.join()
def _mse_per_dim(y1, y2):
    """Return the squared error averaged over the batch axis (dim 0).

    The result keeps one entry per output dimension.
    """
    return torch.mean((y1 - y2) ** 2, dim=0)


def _smooth_l1_per_dim(y1, y2):
    """Return the smooth-L1 error averaged over the batch axis (dim 0).

    Absolute errors below 1 are replaced by ``0.5 * err ** 2``; larger
    errors are kept linear.  The result keeps one entry per output dimension.
    """
    abs_err = torch.abs(y1 - y2)
    per_elem = torch.where(abs_err < 1, 0.5 * abs_err ** 2, abs_err)
    return torch.mean(per_elem, dim=0)


def _wrap_angle_error(abs_err):
    """Fold absolute errors in (1, 2] back to ``2 - err``.

    NOTE(review): the thresholds suggest the angle is normalized to [-1, 1]
    so that an error above 1 is shorter the other way round -- confirm
    against the project's ``normalize`` implementation.
    """
    wrap = (abs_err > 1.0) * (abs_err <= 2.0)
    return torch.where(wrap, 2 * 1.0 - abs_err, abs_err)


def _mse_decoupled(y1, y2):
    """Per-column MSE over columns 0..3, with column 2 treated as an angle."""
    l_0 = torch.abs(y1[:, 0] - y2[:, 0]) ** 2
    l_1 = torch.abs(y1[:, 1] - y2[:, 1]) ** 2
    l_2 = _wrap_angle_error(torch.abs(y1[:, 2] - y2[:, 2])) ** 2
    l_3 = torch.abs(y1[:, 3] - y2[:, 3]) ** 2
    return torch.stack([torch.mean(l_0), torch.mean(l_1),
                        torch.mean(l_2), torch.mean(l_3)])


def _smooth_l1_decoupled(y1, y2):
    """Per-column smooth-L1 over columns 0..1, with column 1 treated as an angle.

    This only matches the cartpole position layout; other systems need
    their own variant (TODO carried over from the original code).
    """
    l_0 = torch.abs(y1[:, 0] - y2[:, 0])
    l_1 = _wrap_angle_error(torch.abs(y1[:, 1] - y2[:, 1]))
    l_0 = torch.where(l_0 < 1, 0.5 * l_0 ** 2, l_0)
    l_1 = torch.where(l_1 < 1, 0.5 * l_1 ** 2, l_1)
    return torch.stack([torch.mean(l_0), torch.mean(l_1)])


def _build_pos_vel_losses(loss_name):
    """Return ``(position_loss, velocity_loss)`` callables for ``loss_name``.

    Raises:
        ValueError: if ``loss_name`` is not one of the supported losses.
    """
    if loss_name == 'mse':
        return _mse_per_dim, _mse_per_dim
    if loss_name == 'l1_smooth':
        return _smooth_l1_per_dim, _smooth_l1_per_dim
    if loss_name == 'mse_decoupled':
        return _mse_decoupled, _mse_decoupled
    if loss_name == 'l1_smooth_decoupled':
        # only the position network uses the angle-aware loss
        return _smooth_l1_decoupled, _smooth_l1_per_dim
    raise ValueError('unsupported loss: %r' % (loss_name,))


def _select_pos_vel_environment(env_type):
    """Resolve the per-environment settings used by ``main``.

    Returns a dict with keys ``normalize``, ``system``, ``dynamics``,
    ``enforce_bounds``, ``step_sz``, ``cae``, ``mlp``, ``pos_indices`` and
    ``vel_indices`` (the two index entries are integer numpy arrays).

    Raises:
        ValueError: if ``env_type`` is unknown, or if it has no
            position/velocity index split configured.
    """
    cpp_propagator = _sst_module.SystemPropagator()
    # defaults, overridden by the obstacle environments below
    cae = cae_identity
    mlp = MLP
    pos_indices = None
    vel_indices = None

    # env name -> (attribute of mlp_cartpole, position columns, velocity columns)
    cartpole_variants = {
        'cartpole_obs': ('MLP', [0, 2], [1, 3]),
        'cartpole_obs_2': ('MLP2', [0, 2], [1, 3]),
        'cartpole_obs_3': ('MLP4', [0, 2], [1, 3]),
        'cartpole_obs_4_small': ('MLP3', [0, 2], [1, 3]),
        'cartpole_obs_4_big': ('MLP3', [0, 2], [1, 3]),
        'cartpole_obs_4_small_x_theta': ('MLP3', [0, 1], [2, 3]),
        'cartpole_obs_4_big_x_theta': ('MLP3', [0, 1], [2, 3]),
        'cartpole_obs_4_small_decouple_output': ('MLP3', [0, 2], [1, 3]),
        'cartpole_obs_4_big_decouple_output': ('MLP3', [0, 2], [1, 3]),
    }
    # env name -> (attribute of mlp_acrobot, lazy encoder lookup); the lambdas
    # keep the lookup lazy so only the selected encoder name must exist.
    acrobot_variants = {
        'acrobot_obs': ('MLP', lambda: CAE_acrobot_voxel_2d),
        'acrobot_obs_2': ('MLP2', lambda: CAE_acrobot_voxel_2d_2),
        'acrobot_obs_3': ('MLP3', lambda: CAE_acrobot_voxel_2d_2),
        'acrobot_obs_4': ('MLP3', lambda: CAE_acrobot_voxel_2d_3),
        'acrobot_obs_5': ('MLP', lambda: CAE_acrobot_voxel_2d_3),
        'acrobot_obs_6': ('MLP4', lambda: CAE_acrobot_voxel_2d_3),
        'acrobot_obs_7': ('MLP5', lambda: CAE_acrobot_voxel_2d_3),
        'acrobot_obs_8': ('MLP6', lambda: CAE_acrobot_voxel_2d_3),
    }

    if env_type == 'pendulum':
        normalize = pendulum.normalize
        system = standard_cpp_systems.PSOPTPendulum()
        dynamics = None
        enforce_bounds = None
        step_sz = 0.002
    elif env_type == 'cartpole':
        normalize = cart_pole.normalize
        dynamics = cartpole.dynamics
        system = _sst_module.CartPole()
        enforce_bounds = cartpole.enforce_bounds
        step_sz = 0.002
    elif env_type in cartpole_variants:
        mlp_name, pos_indices, vel_indices = cartpole_variants[env_type]
        normalize = cart_pole_obs.normalize
        system = _sst_module.PSOPTCartPole()
        mlp = getattr(mlp_cartpole, mlp_name)
        cae = CAE_cartpole_voxel_2d
        # dynamics: None    -- without integration to dense trajectory
        dynamics = lambda x, u, t: cpp_propagator.propagate(system, x, u, t)
        enforce_bounds = cart_pole_obs.enforce_bounds
        step_sz = 0.002
    elif env_type in acrobot_variants:
        mlp_name, cae_getter = acrobot_variants[env_type]
        normalize = acrobot_obs.normalize
        system = _sst_module.PSOPTAcrobot()
        mlp = getattr(mlp_acrobot, mlp_name)
        cae = cae_getter()
        dynamics = lambda x, u, t: cpp_propagator.propagate(system, x, u, t)
        enforce_bounds = acrobot_obs.enforce_bounds
        step_sz = 0.02
        # NOTE(review): acrobot_obs_7 / acrobot_obs_8 had no index split in
        # the original; they are given the split shared by every other
        # acrobot variant -- confirm this is the intended layout.
        pos_indices = [0, 1]
        vel_indices = [2, 3]
    else:
        raise ValueError('unsupported env_type: %r' % (env_type,))

    if pos_indices is None or vel_indices is None:
        # the position/velocity networks cannot be trained without a split
        raise ValueError('env_type %r has no pos_indices/vel_indices configured'
                         % (env_type,))

    return {
        'normalize': normalize,
        'system': system,
        'dynamics': dynamics,
        'enforce_bounds': enforce_bounds,
        'step_sz': step_sz,
        'cae': cae,
        'mlp': mlp,
        # arrays (not lists) so that an integer offset can be added to them
        'pos_indices': np.array(pos_indices),
        'vel_indices': np.array(vel_indices),
    }


def _configure_optimizer(net, opt_name, learning_rate):
    """Install the optimizer called ``opt_name`` on ``net`` via ``set_opt``.

    Unknown names are ignored, leaving whatever optimizer ``net`` already has.
    """
    if opt_name == 'Adagrad':
        net.set_opt(torch.optim.Adagrad, lr=learning_rate)
    elif opt_name == 'Adam':
        net.set_opt(torch.optim.Adam, lr=learning_rate)
    elif opt_name == 'SGD':
        net.set_opt(torch.optim.SGD, lr=learning_rate, momentum=0.9)
    elif opt_name == 'ASGD':
        net.set_opt(torch.optim.ASGD, lr=learning_rate)


def _batch_to_var(batch, normalize, world_size):
    """Convert a numeric batch to a normalized float variable.

    ``np.array`` always copies, so ``normalize`` can never modify the
    dataset the batch was sliced from.
    """
    tensor = torch.FloatTensor(np.array(batch, dtype=np.float32))
    return to_var(normalize(tensor, world_size))


def _obs_batch_to_var(obs, env_indices_i):
    """Return the obstacle encodings for ``env_indices_i`` (None if no obs)."""
    if obs is None:
        return None
    bobs = obs[env_indices_i].astype(np.float32)
    return to_var(torch.FloatTensor(bobs))


def main(args):
    """Train separate position and velocity KMPNet networks.

    Reads the following attributes of ``args``: ``device``, ``env_type``,
    ``loss``, ``total_input_size``, ``AE_input_size``, ``mlp_input_size``,
    ``output_size``, ``model_dir``, ``multigoal``, ``learning_rate``, ``opt``,
    ``num_steps``, ``start_epoch``, ``direction``, ``no_env``,
    ``no_motion_paths``, ``path_folder``, ``world_size``, ``num_epochs`` and
    ``batch_size``.

    Side effects: creates the model directory, writes a checkpoint for both
    networks every 50 epochs, logs averaged train/validation losses to a
    ``SummaryWriter`` under ``./runs/`` and exports them to
    ``./all_scalars.json``.

    Raises:
        ValueError: for an unsupported ``args.env_type`` or ``args.loss``.
    """
    if torch.cuda.is_available():
        torch.cuda.set_device(args.device)

    # environment setting
    env = _select_pos_vel_environment(args.env_type)
    normalize = env['normalize']
    cae = env['cae']
    mlp = env['mlp']
    pos_indices = env['pos_indices']
    vel_indices = env['vel_indices']

    # set loss for mpnet
    loss_f_p, loss_f_v = _build_pos_vel_losses(args.loss)

    # "decouple_output" environments split only the outputs; both networks
    # then receive the full input vector.
    decouple_output = 'decouple_output' in args.env_type
    if decouple_output:
        print('mpnet using decoupled output')
        net_input_size = args.total_input_size
    else:
        net_input_size = args.total_input_size // 2
    mpnet_pnet = KMPNet(net_input_size, args.AE_input_size, args.mlp_input_size,
                        args.output_size // 2, cae, mlp, loss_f_p)
    mpnet_vnet = KMPNet(net_input_size, args.AE_input_size, args.mlp_input_size,
                        args.output_size // 2, cae, mlp, loss_f_v)

    # model directory: the loss name is only part of the path for non-mse runs
    multigoal_suffix = '' if args.multigoal == 0 else '_multigoal'
    if args.loss == 'mse':
        run_dir = "_lr%f_%s_step_%d" % (args.learning_rate, args.opt, args.num_steps)
    else:
        run_dir = "_lr%f_%s_loss_%s_step_%d" % (args.learning_rate, args.opt,
                                                args.loss, args.num_steps)
    model_dir = args.model_dir + args.env_type + run_dir + multigoal_suffix + "/"
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)

    checkpoint_fmt = 'kmpnet_%s_epoch_%d_direction_%d_step_%d.pkl'
    model_pnet_path = checkpoint_fmt % ('pnet', args.start_epoch, args.direction, args.num_steps)
    model_vnet_path = checkpoint_fmt % ('vnet', args.start_epoch, args.direction, args.num_steps)

    # load previously trained model if start epoch > 0
    torch_seed, np_seed, py_seed = 0, 0, 0
    if args.start_epoch > 0:
        load_net_state(mpnet_pnet, os.path.join(model_dir, model_pnet_path))
        load_net_state(mpnet_vnet, os.path.join(model_dir, model_vnet_path))
        torch_seed, np_seed, py_seed = load_seed(os.path.join(model_dir, model_pnet_path))
        # set seed after loading
        torch.manual_seed(torch_seed)
        np.random.seed(np_seed)
        random.seed(py_seed)

    # NOTE(review): as in the original, the optimizer choice and the
    # optimizer-state restore only happen when CUDA is available; on a
    # CPU-only host the networks keep whatever optimizer KMPNet sets up
    # itself -- confirm whether that is intended.
    if torch.cuda.is_available():
        for net in (mpnet_pnet, mpnet_vnet):
            net.cuda()
            net.mlp.cuda()
            net.encoder.cuda()
        _configure_optimizer(mpnet_pnet, args.opt, args.learning_rate)
        _configure_optimizer(mpnet_vnet, args.opt, args.learning_rate)
        if args.start_epoch > 0:
            # each network restores its optimizer from its own checkpoint
            load_opt_state(mpnet_pnet, os.path.join(model_dir, model_pnet_path))
            load_opt_state(mpnet_vnet, os.path.join(model_dir, model_vnet_path))

    # load train and test data
    print('loading...')
    obs, waypoint_dataset, waypoint_targets, env_indices, \
    _, _, _, _ = data_loader.load_train_dataset(N=args.no_env, NP=args.no_motion_paths,
                                                data_folder=args.path_folder, obs_f=True,
                                                direction=args.direction,
                                                dynamics=env['dynamics'],
                                                enforce_bounds=env['enforce_bounds'],
                                                system=env['system'], step_sz=env['step_sz'],
                                                num_steps=args.num_steps, multigoal=args.multigoal)

    # randomize the dataset before training
    data = list(zip(waypoint_dataset, waypoint_targets, env_indices))
    random.shuffle(data)
    dataset, targets, env_indices = list(zip(*data))
    dataset = np.array(list(dataset))
    targets = np.array(targets)
    env_indices = np.array(list(env_indices))

    # Each input row holds two states back to back, so the same columns are
    # picked again at an offset of half the input size.
    half_input = args.total_input_size // 2
    p_columns = np.concatenate([pos_indices, pos_indices + half_input])
    v_columns = np.concatenate([vel_indices, vel_indices + half_input])
    print(p_columns)
    if decouple_output:
        # only decouple output
        print('only decouple output but not input')
        p_dataset = dataset
        v_dataset = dataset
    else:
        p_dataset = dataset[:, p_columns]
        v_dataset = dataset[:, v_columns]
    print(p_dataset.shape)
    print(v_dataset.shape)

    p_targets = targets[:, pos_indices]
    v_targets = targets[:, vel_indices]

    # use 5% as validation dataset; an explicit split index keeps the
    # training set intact when the dataset is too small for any validation
    # rows (val_len == 0)
    val_len = int(len(dataset) * 0.05)
    split = len(dataset) - val_len
    val_p_dataset, p_dataset = p_dataset[split:], p_dataset[:split]
    val_v_dataset, v_dataset = v_dataset[split:], v_dataset[:split]
    val_p_targets, p_targets = p_targets[split:], p_targets[:split]
    val_v_targets, v_targets = v_targets[split:], v_targets[:split]
    val_env_indices, env_indices = env_indices[split:], env_indices[:split]

    # Train the Models
    print('training...')
    writer_fname = 'pos_vel_%s_%f_%s_direction_%d_step_%d' % (
        args.env_type, args.learning_rate, args.opt, args.direction, args.num_steps)
    if args.loss != 'mse':
        writer_fname += '_loss_%s' % (args.loss,)
    writer_fname += multigoal_suffix

    writer = SummaryWriter('./runs/' + writer_fname)
    record_i = 0
    val_record_i = 0
    p_loss_avg_i = 0
    p_val_loss_avg_i = 0
    p_loss_avg = 0.
    p_val_loss_avg = 0.
    v_loss_avg_i = 0
    v_val_loss_avg_i = 0
    v_loss_avg = 0.
    v_val_loss_avg = 0.

    loss_steps = 100  # record every 100 loss

    world_size = np.array(args.world_size)
    pos_world_size = list(world_size[pos_indices])
    vel_world_size = list(world_size[vel_indices])
    # inputs are normalized with the full world size when only the output
    # is decoupled, otherwise with the matching half
    if decouple_output:
        p_input_world_size = args.world_size
        v_input_world_size = args.world_size
    else:
        p_input_world_size = pos_world_size
        v_input_world_size = vel_world_size

    for epoch in range(args.start_epoch + 1, args.num_epochs + 1):
        print('epoch' + str(epoch))
        val_i = 0
        for i in range(0, len(p_dataset), args.batch_size):
            print('epoch: %d, training... path: %d' % (epoch, i + 1))
            batch = slice(i, i + args.batch_size)
            p_dataset_i = p_dataset[batch]
            v_dataset_i = v_dataset[batch]
            print('p_bi shape:')
            print(p_dataset_i.shape)
            print('v_bi shape:')
            print(v_dataset_i.shape)

            # edit: disable this for investigation of the good weights for training, and for wrapping
            if decouple_output:
                print('using normalizatino of decoupled output')
            p_bi = _batch_to_var(p_dataset_i, normalize, p_input_world_size)
            v_bi = _batch_to_var(v_dataset_i, normalize, v_input_world_size)
            p_bt = _batch_to_var(p_targets[batch], normalize, pos_world_size)
            v_bt = _batch_to_var(v_targets[batch], normalize, vel_world_size)

            mpnet_pnet.zero_grad()
            mpnet_vnet.zero_grad()

            bobs = _obs_batch_to_var(obs, env_indices[batch])

            print('-------pnet-------')
            print('before training losses:')
            print(mpnet_pnet.loss(mpnet_pnet(p_bi, bobs), p_bt))
            mpnet_pnet.step(p_bi, bobs, p_bt)
            # evaluated once: the printed value is the one that is recorded
            p_loss = mpnet_pnet.loss(mpnet_pnet(p_bi, bobs), p_bt)
            print('after training losses:')
            print(p_loss)
            p_loss_avg += p_loss.cpu().data
            p_loss_avg_i += 1

            print('-------vnet-------')
            print('before training losses:')
            print(mpnet_vnet.loss(mpnet_vnet(v_bi, bobs), v_bt))
            mpnet_vnet.step(v_bi, bobs, v_bt)
            v_loss = mpnet_vnet.loss(mpnet_vnet(v_bi, bobs), v_bt)
            print('after training losses:')
            print(v_loss)
            v_loss_avg += v_loss.cpu().data
            v_loss_avg_i += 1

            if p_loss_avg_i >= loss_steps:
                p_loss_avg = p_loss_avg / p_loss_avg_i
                writer.add_scalar('p_train_loss_0', p_loss_avg[0], record_i)
                writer.add_scalar('p_train_loss_1', p_loss_avg[1], record_i)

                v_loss_avg = v_loss_avg / v_loss_avg_i
                writer.add_scalar('v_train_loss_0', v_loss_avg[0], record_i)
                writer.add_scalar('v_train_loss_1', v_loss_avg[1], record_i)

                record_i += 1
                p_loss_avg = 0.
                p_loss_avg_i = 0
                v_loss_avg = 0.
                v_loss_avg_i = 0

            # validation: one validation batch per training batch, cycling
            # through the validation set
            if val_len <= 0:
                continue
            val_batch = slice(val_i, val_i + args.batch_size)
            val_i = val_i + args.batch_size
            if val_i >= val_len:
                # wrap before the cursor reaches the end, otherwise the next
                # slice would be empty
                val_i = 0

            p_dataset_i = val_p_dataset[val_batch]
            v_dataset_i = val_v_dataset[val_batch]
            print('p_bi shape:')
            print(p_dataset_i.shape)
            print('v_bi shape:')
            print(v_dataset_i.shape)

            p_bi = _batch_to_var(p_dataset_i, normalize, p_input_world_size)
            v_bi = _batch_to_var(v_dataset_i, normalize, v_input_world_size)
            p_bt = _batch_to_var(val_p_targets[val_batch], normalize, pos_world_size)
            v_bt = _batch_to_var(val_v_targets[val_batch], normalize, vel_world_size)
            bobs = _obs_batch_to_var(obs, val_env_indices[val_batch])

            print('-------pnet loss--------')
            p_loss = mpnet_pnet.loss(mpnet_pnet(p_bi, bobs), p_bt)
            # the one-element tuple makes '%' format the tensor as a value
            print('validation loss: %s' % (p_loss.cpu().data,))
            p_val_loss_avg += p_loss.cpu().data
            p_val_loss_avg_i += 1

            print('-------vnet loss--------')
            v_loss = mpnet_vnet.loss(mpnet_vnet(v_bi, bobs), v_bt)
            print('validation loss: %s' % (v_loss.cpu().data,))
            v_val_loss_avg += v_loss.cpu().data
            v_val_loss_avg_i += 1

            if p_val_loss_avg_i >= loss_steps:
                p_val_loss_avg = p_val_loss_avg / p_val_loss_avg_i
                writer.add_scalar('p_val_loss_0', p_val_loss_avg[0], val_record_i)
                writer.add_scalar('p_val_loss_1', p_val_loss_avg[1], val_record_i)
                v_val_loss_avg = v_val_loss_avg / v_val_loss_avg_i
                writer.add_scalar('v_val_loss_0', v_val_loss_avg[0], val_record_i)
                writer.add_scalar('v_val_loss_1', v_val_loss_avg[1], val_record_i)

                val_record_i += 1
                p_val_loss_avg = 0.
                p_val_loss_avg_i = 0
                v_val_loss_avg = 0.
                v_val_loss_avg_i = 0

        # Save the models
        if epoch > 0 and epoch % 50 == 0:
            model_pnet_path = checkpoint_fmt % ('pnet', epoch, args.direction, args.num_steps)
            model_vnet_path = checkpoint_fmt % ('vnet', epoch, args.direction, args.num_steps)
            save_state(mpnet_pnet, torch_seed, np_seed, py_seed, os.path.join(model_dir, model_pnet_path))
            save_state(mpnet_vnet, torch_seed, np_seed, py_seed, os.path.join(model_dir, model_vnet_path))

    writer.export_scalars_to_json("./all_scalars.json")
    writer.close()
Example #13
0
def main(args):
    """Train a ``KMPNet`` on the cost dataset of the selected environment.

    The function picks the environment-specific helpers from
    ``args.env_type``, builds the loss named by ``args.loss``, optionally
    resumes from the checkpoint of ``args.start_epoch`` and then trains for
    epochs ``args.start_epoch + 1 .. args.num_epochs``.  After every training
    batch one batch of the held-out 5% validation split is evaluated.
    Averaged losses are written to a ``SummaryWriter`` every ``loss_steps``
    batches and a checkpoint is saved every 50 epochs.

    Args:
        args: parsed options.  Attributes read here: ``device``, ``env_type``,
            ``loss``, ``total_input_size``, ``AE_input_size``,
            ``mlp_input_size``, ``output_size``, ``model_dir``,
            ``learning_rate``, ``opt``, ``num_steps``, ``start_epoch``,
            ``direction``, ``no_env``, ``no_motion_paths``, ``path_folder``,
            ``world_size``, ``batch_size`` and ``num_epochs``.

    Raises:
        ValueError: if ``args.env_type`` or ``args.loss`` is not one of the
            names handled below, or if the chosen environment does not define
            the encoder/MLP classes needed to build the network.
    """
    if torch.cuda.is_available():
        torch.cuda.set_device(args.device)

    # ---- environment setting ----
    multigoal = False
    cae = None
    mlp = None
    cpp_propagator = _sst_module.SystemPropagator()

    # (mlp, cae) classes of every acrobot variant.  The pairs are wrapped in
    # lambdas so that only the names of the selected variant are looked up.
    acrobot_networks = {
        'acrobot_obs': lambda: (mlp_acrobot.MLP, CAE_acrobot_voxel_2d),
        'acrobot_obs_2': lambda: (mlp_acrobot.MLP2, CAE_acrobot_voxel_2d_2),
        'acrobot_obs_3': lambda: (mlp_acrobot.MLP3, CAE_acrobot_voxel_2d_2),
        'acrobot_obs_4': lambda: (mlp_acrobot.MLP3, CAE_acrobot_voxel_2d_3),
        'acrobot_obs_5': lambda: (mlp_acrobot.MLP, CAE_acrobot_voxel_2d_3),
        'acrobot_obs_6': lambda: (mlp_acrobot.MLP4, CAE_acrobot_voxel_2d_3),
        'acrobot_obs_7': lambda: (mlp_acrobot.MLP5, CAE_acrobot_voxel_2d_3),
        'acrobot_obs_8': lambda: (mlp_acrobot.MLP6, CAE_acrobot_voxel_2d_3),
    }

    # The trajectory length handed to the data loader comes from
    # args.num_steps, so no per-environment step count is kept here.
    if args.env_type == 'pendulum':
        normalize = pendulum.normalize
        system = standard_cpp_systems.PSOPTPendulum()
        dynamics = None
        enforce_bounds = None
        step_sz = 0.002
    elif args.env_type == 'cartpole':
        normalize = cart_pole.normalize
        dynamics = cartpole.dynamics
        system = _sst_module.CartPole()
        enforce_bounds = cartpole.enforce_bounds
        step_sz = 0.002
    elif args.env_type == 'cartpole_obs':
        normalize = cart_pole_obs.normalize
        system = _sst_module.CartPole()
        dynamics = cartpole.dynamics
        enforce_bounds = cartpole.enforce_bounds
        step_sz = 0.002
        cae = cae_identity
        mlp = MLP
    elif args.env_type in ('cartpole_obs_4', 'cartpole_obs_4_multigoal'):
        normalize = cart_pole_obs.normalize
        system = _sst_module.PSOPTCartPole()
        mlp = mlp_cartpole.MLP3_no_dropout
        cae = CAE_cartpole_voxel_2d
        dynamics = lambda x, u, t: cpp_propagator.propagate(system, x, u, t)
        multigoal = args.env_type == 'cartpole_obs_4_multigoal'
        enforce_bounds = cart_pole_obs.enforce_bounds
        step_sz = 0.002
    elif args.env_type in acrobot_networks:
        normalize = acrobot_obs.normalize
        system = _sst_module.PSOPTAcrobot()
        mlp, cae = acrobot_networks[args.env_type]()
        dynamics = lambda x, u, t: cpp_propagator.propagate(system, x, u, t)
        enforce_bounds = acrobot_obs.enforce_bounds
        step_sz = 0.02
    else:
        raise ValueError('unsupported env_type: %r' % (args.env_type,))

    if cae is None or mlp is None:
        # 'pendulum' and 'cartpole' configure the dynamics but no network
        # classes; fail here with a clear message instead of an
        # UnboundLocalError when the network is built.
        raise ValueError(
            'env_type %r does not define the encoder/MLP classes needed to '
            'build KMPNet' % (args.env_type,))

    # ---- loss for mpnet ----
    if args.loss == 'mse':

        def mse_loss(y1, y2):
            # mean over the batch dimension: one loss value per output dim
            return torch.mean((y1 - y2)**2, dim=0)

        loss_f = mse_loss

    elif args.loss == 'l1_smooth':

        def l1_smooth_loss(y1, y2):
            l1 = torch.abs(y1 - y2)
            # quadratic below 1, linear above (as written in the original)
            l = torch.where(l1 < 1, 0.5 * l1**2, l1)
            # mean over the batch dimension: one loss value per output dim
            l = torch.mean(l, dim=0)
            # The value was previously computed but never returned, which
            # made this loss evaluate to None.
            return l

        loss_f = l1_smooth_loss

    elif args.loss == 'mse_decoupled':

        def mse_decoupled(y1, y2):
            # for angle terms, wrap it to -pi~pi
            l_0 = torch.abs(y1[:, 0] - y2[:, 0])**2
            l_1 = torch.abs(y1[:, 1] - y2[:, 1])**2
            l_2 = torch.abs(y1[:, 2] - y2[:, 2])  # angular dimension
            l_3 = torch.abs(y1[:, 3] - y2[:, 3])**2

            # np.pi after normalization is 1.0
            cond = (l_2 > 1.0) * (l_2 <= 2.0)
            l_2 = torch.where(cond, 2.0 - l_2, l_2)
            l_2 = l_2**2
            return torch.stack([
                torch.mean(l_0),
                torch.mean(l_1),
                torch.mean(l_2),
                torch.mean(l_3),
            ])

        loss_f = mse_decoupled

    else:
        raise ValueError('unsupported loss: %r' % (args.loss,))

    mpnet = KMPNet(args.total_input_size, args.AE_input_size,
                   args.mlp_input_size, args.output_size, cae, mlp, loss_f)

    # ---- load previously trained model if start epoch > 0 ----
    model_dir = args.model_dir
    model_dir = model_dir + 'cost_' + args.env_type + "_lr%f_%s_step_%d/" % (
        args.learning_rate, args.opt, args.num_steps)
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)
    model_path = 'cost_kmpnet_epoch_%d_direction_%d_step_%d.pkl' % (
        args.start_epoch, args.direction, args.num_steps)
    torch_seed, np_seed, py_seed = 0, 0, 0
    if args.start_epoch > 0:
        load_net_state(mpnet, os.path.join(model_dir, model_path))
        torch_seed, np_seed, py_seed = load_seed(
            os.path.join(model_dir, model_path))
        # set seed after loading
        torch.manual_seed(torch_seed)
        np.random.seed(np_seed)
        random.seed(py_seed)

    if torch.cuda.is_available():
        mpnet.cuda()
        mpnet.mlp.cuda()
        mpnet.encoder.cuda()

    # Select the optimizer after the (optional) move to the GPU.  This used
    # to be nested under the CUDA check, so args.opt / args.learning_rate
    # were silently ignored on CPU-only machines.
    if args.opt == 'Adagrad':
        mpnet.set_opt(torch.optim.Adagrad, lr=args.learning_rate)
    elif args.opt == 'Adam':
        mpnet.set_opt(torch.optim.Adam, lr=args.learning_rate)
    elif args.opt == 'SGD':
        mpnet.set_opt(torch.optim.SGD, lr=args.learning_rate, momentum=0.9)
    elif args.opt == 'ASGD':
        mpnet.set_opt(torch.optim.ASGD, lr=args.learning_rate)
    if args.start_epoch > 0:
        load_opt_state(mpnet, os.path.join(model_dir, model_path))

    # ---- load train and test data ----
    print('loading...')
    obs, cost_dataset, cost_targets, env_indices, \
    _, _, _, _ = data_loader.load_train_dataset_cost(N=args.no_env, NP=args.no_motion_paths,
                                                data_folder=args.path_folder, obs_f=True,
                                                direction=args.direction,
                                                dynamics=dynamics, enforce_bounds=enforce_bounds,
                                                system=system, step_sz=step_sz, num_steps=args.num_steps,
                                                multigoal=multigoal)
    # randomize the dataset before training
    data = list(zip(cost_dataset, cost_targets, env_indices))
    random.shuffle(data)
    dataset, targets, env_indices = zip(*data)
    dataset = np.array(dataset)
    targets = np.array(targets)
    env_indices = np.array(env_indices)

    # use 5% as validation dataset.  An explicit split index is used because
    # dataset[-0:] / dataset[:-0] would put everything into the validation
    # set and leave the training set empty when val_len is 0.
    val_len = int(len(dataset) * 0.05)
    split = len(dataset) - val_len
    val_dataset = dataset[split:]
    val_targets = targets[split:]
    val_env_indices = env_indices[split:]

    dataset = dataset[:split]
    targets = targets[:split]
    env_indices = env_indices[:split]

    # ---- train the models ----
    print('training...')
    writer_fname = 'cost_%s_%f_%s_direction_%d_step_%d' % (
        args.env_type, args.learning_rate, args.opt, args.direction,
        args.num_steps)
    writer = SummaryWriter('./runs/' + writer_fname)
    record_i = 0
    val_record_i = 0
    loss_avg_i = 0
    val_loss_avg_i = 0
    loss_avg = 0.
    val_loss_avg = 0.
    loss_steps = 100  # record every 100 loss
    for epoch in range(args.start_epoch + 1, args.num_epochs + 1):
        print('epoch' + str(epoch))
        val_i = 0
        for i in range(0, len(dataset), args.batch_size):
            print('epoch: %d, training... path: %d' % (epoch, i + 1))
            dataset_i = dataset[i:i + args.batch_size]
            targets_i = targets[i:i + args.batch_size]
            env_indices_i = env_indices[i:i + args.batch_size]
            # record
            bi = dataset_i.astype(np.float32)
            print('bi shape:')
            print(bi.shape)
            bt = targets_i
            bi = torch.FloatTensor(bi)
            bt = torch.FloatTensor(bt)
            bi = normalize(bi, args.world_size)
            mpnet.zero_grad()
            bi = to_var(bi)
            bt = to_var(bt)
            if obs is None:
                bobs = None
            else:
                bobs = obs[env_indices_i].astype(np.float32)
                bobs = torch.FloatTensor(bobs)
                bobs = to_var(bobs)
            print('before training losses:')
            print(mpnet.loss(mpnet(bi, bobs), bt))
            mpnet.step(bi, bobs, bt)
            print('after training losses:')
            print(mpnet.loss(mpnet(bi, bobs), bt))
            loss = mpnet.loss(mpnet(bi, bobs), bt)
            loss_avg += loss.cpu().data
            loss_avg_i += 1
            if loss_avg_i >= loss_steps:
                loss_avg = loss_avg / loss_avg_i
                writer.add_scalar('train_loss', loss_avg, record_i)
                record_i += 1
                loss_avg = 0.
                loss_avg_i = 0

            # validation
            if val_len == 0:
                # dataset too small to hold out anything: nothing to evaluate
                continue
            # calculate the corresponding batch in val_dataset
            dataset_i = val_dataset[val_i:val_i + args.batch_size]
            targets_i = val_targets[val_i:val_i + args.batch_size]
            env_indices_i = val_env_indices[val_i:val_i + args.batch_size]
            val_i = val_i + args.batch_size
            # ">=" (not ">"): when val_i lands exactly on val_len the next
            # slice would be empty and produce a meaningless loss.
            if val_i >= val_len:
                val_i = 0
            # record
            bi = dataset_i.astype(np.float32)
            print('bi shape:')
            print(bi.shape)
            bt = targets_i
            bi = torch.FloatTensor(bi)
            bt = torch.FloatTensor(bt)
            bi = normalize(bi, args.world_size)
            bi = to_var(bi)
            bt = to_var(bt)
            if obs is None:
                bobs = None
            else:
                bobs = obs[env_indices_i].astype(np.float32)
                bobs = torch.FloatTensor(bobs)
                bobs = to_var(bobs)
            loss = mpnet.loss(mpnet(bi, bobs), bt)
            print('validation loss: %f' % (loss.cpu().data))

            val_loss_avg += loss.cpu().data
            val_loss_avg_i += 1
            if val_loss_avg_i >= loss_steps:
                val_loss_avg = val_loss_avg / val_loss_avg_i
                writer.add_scalar('val_loss', val_loss_avg, val_record_i)
                val_record_i += 1
                val_loss_avg = 0.
                val_loss_avg_i = 0
        # Save the models
        if epoch > 0 and epoch % 50 == 0:
            model_path = 'cost_kmpnet_epoch_%d_direction_%d_step_%d.pkl' % (
                epoch, args.direction, args.num_steps)
            save_state(mpnet, torch_seed, np_seed, py_seed,
                       os.path.join(model_dir, model_path))
    writer.export_scalars_to_json("./all_scalars.json")
    writer.close()
Example #14
0
def main(args):
    """Evaluate the DeepSMP planner on the "seen" test environments.

    Every test path is planned in its own child process (``plan_one_path``),
    which reports the planning time, or ``-1`` when the planner returns an
    empty solution, through a queue.  Running success rate and timing
    statistics are printed as the evaluation proceeds.

    Args:
        args: parsed options.  Attributes read here: ``model_path``,
            ``env_type``, ``bvp_iter``, ``plan_type``, ``delta_near``,
            ``delta_drain``, ``data_folder``, ``seen_N``, ``seen_NP``,
            ``seen_s``, ``seen_sp``, ``unseen_N``, ``unseen_NP``,
            ``unseen_s`` and ``unseen_sp``.

    Raises:
        ValueError: if ``args.env_type`` is not configured below, or if
            ``args.seen_N`` is not positive (only the seen set is evaluated).
    """
    # Raised by Queue.get_nowait(); imported locally so this function does
    # not rely on a module-level import.
    from queue import Empty

    # set seed
    print(args.model_path)
    # All three draws are kept (in this order) so the numpy/python seeds stay
    # the same as before; the torch seed itself is not applied.
    torch_seed = np.random.randint(low=0, high=1000)
    np_seed = np.random.randint(low=0, high=1000)
    py_seed = np.random.randint(low=0, high=1000)
    np.random.seed(np_seed)
    random.seed(py_seed)

    # BVP solver iteration budget of every configured acrobot variant.
    # 'acrobot_obs_6' used to be handled by two consecutive branches: the
    # first one used `system`/`cpp_propagator` before assigning them and the
    # complete second one was unreachable.  One entry per variant fixes both.
    acrobot_bvp_iters = {
        'acrobot_obs': 200,
        'acrobot_obs_2': 400,
        'acrobot_obs_3': 400,
        'acrobot_obs_5': 400,
        'acrobot_obs_6': 400,
        'acrobot_obs_8': 400,
    }
    # Shared by the configuration below and by plan_one_path so that every
    # configured acrobot variant gets obs_width and the planner settings.
    acrobot_envs = tuple(acrobot_bvp_iters)

    # setup evaluation function and load function
    # Only `obs_f` and `obs_width` are read again later in this function; the
    # remaining values are kept as the record of the evaluation setup.
    if args.env_type == 'pendulum':
        obs_file = None
        obc_file = None
        obs_f = False
        step_sz = 0.002
        num_steps = 20
    elif args.env_type == 'cartpole_obs':
        normalize = cartpole.normalize
        unnormalize = cartpole.unnormalize
        obs_file = None
        obc_file = None
        dynamics = cartpole.dynamics
        jax_dynamics = cartpole.jax_dynamics
        enforce_bounds = cartpole.enforce_bounds
        cae = CAE_acrobot_voxel_2d
        mlp = mlp_acrobot.MLP
        obs_f = True
        step_sz = 0.002
        num_steps = 20
        goal_S0 = np.identity(4)
        goal_rho0 = 1.0
    elif args.env_type in acrobot_bvp_iters:
        obs_file = None
        obc_file = None
        system = _sst_module.PSOPTAcrobot()
        cpp_propagator = _sst_module.SystemPropagator()
        dynamics = lambda x, u, t: cpp_propagator.propagate(system, x, u, t)
        obs_f = True
        bvp_solver = _sst_module.PSOPTBVPWrapper(system, 4, 1, 0)
        bvp_iters = acrobot_bvp_iters[args.env_type]
        step_sz = 0.02
        num_steps = 21
        traj_opt = lambda x0, x1, step_sz, num_steps, x_init, u_init, t_init: bvp_solver.solve(
            x0, x1, bvp_iters, num_steps, step_sz * 1, step_sz *
            (num_steps - 1), x_init, u_init, t_init)
        goal_S0 = np.diag([1., 1., 0, 0])
        goal_rho0 = 1.0
        obs_width = 6.0
        goal_radius = 2.0
        random_seed = 0
        delta_near = 0.1
        delta_drain = 0.05
    else:
        raise ValueError('unsupported env_type: %r' % (args.env_type,))

    # Paths of the exported networks handed to the planner.
    mlp_path = os.path.join(
        os.getcwd() + '/c++/',
        'acrobot_obs_MLP_lr0.010000_epoch_2850_step_20.pt')
    encoder_path = os.path.join(
        os.getcwd() + '/c++/',
        'acrobot_obs_encoder_lr0.010000_epoch_2850_step_20.pt')
    cost_mlp_path = os.path.join(
        os.getcwd() + '/c++/',
        'costnet_acrobot_obs_8_MLP_epoch_300_step_20.pt')
    cost_encoder_path = os.path.join(
        os.getcwd() + '/c++/',
        'costnet_acrobot_obs_8_encoder_epoch_300_step_20.pt')

    print('mlp_path:')
    print(mlp_path)

    #####################################################
    def plan_one_path(obs_i, obs, obc, start_state, goal_state,
                      goal_inform_state, max_iteration, out_queue):
        """Plan a single path and report the outcome on ``out_queue``.

        Puts the planning time in seconds on the queue when the planner
        returns a non-empty solution, and ``-1`` otherwise.
        """
        # NOTE(review): only the acrobot branch defines propagate_system,
        # psopt_system, distance_computer and goal_radius, which the planner
        # calls below need; the other two branches cannot complete a plan.
        if args.env_type == 'pendulum':
            system = standard_cpp_systems.PSOPTPendulum()
            bvp_solver = _sst_module.PSOPTBVPWrapper(system, 2, 1, 0)
            step_sz = 0.002
            num_steps = 20
            traj_opt = lambda x0, x1: bvp_solver.solve(x0, x1, 200, num_steps,
                                                       1, 20, step_sz)

        elif args.env_type == 'cartpole_obs':
            system = _sst_module.CartPole()
            bvp_solver = _sst_module.PSOPTBVPWrapper(system, 4, 1, 0)
            step_sz = 0.002
            num_steps = 20
            traj_opt = lambda x0, x1, x_init, u_init, t_init: bvp_solver.solve(
                x0, x1, 200, num_steps, step_sz * 1, step_sz * 50, x_init,
                u_init, t_init)
            goal_S0 = np.identity(4)
            goal_rho0 = 1.0
        elif args.env_type in acrobot_envs:
            obs_width = 6.0
            psopt_system = _sst_module.PSOPTAcrobot()
            propagate_system = standard_cpp_systems.RectangleObs(
                obs, 6., 'acrobot')
            distance_computer = propagate_system.distance_computer()
            step_sz = 0.02
            num_steps = 21
            goal_radius = 2.0
            random_seed = 0
            delta_near = .1
            delta_drain = 0.05

        planner = vis_planners.DeepSMPWrapper(mlp_path, encoder_path,
                                              cost_mlp_path, cost_encoder_path,
                                              args.bvp_iter, num_steps,
                                              step_sz, propagate_system, 3)
        # Time the planning call only.
        time0 = time.time()
        # NOTE(review): goal_inform_state is passed for both goal arguments
        # and goal_state is not used -- confirm this is intended.
        res_x, res_u, res_t = planner.plan("sst", args.plan_type, propagate_system, psopt_system, obc.flatten(), start_state, goal_inform_state, goal_inform_state, \
                                goal_radius, max_iteration, distance_computer, \
                                args.delta_near, args.delta_drain)
        plan_time = time.time() - time0

        print('plan time: %fs' % (plan_time))
        if len(res_x) == 0:
            print('failed.')
            out_queue.put(-1)
        else:
            print('path succeeded.')
            out_queue.put(plan_time)

    ####################################################################################

    # load data
    print('loading...')
    if args.seen_N <= 0:
        # Only the seen test set is evaluated below; without it the original
        # code failed later with a NameError.
        raise ValueError('args.seen_N must be positive, got %r' %
                         (args.seen_N,))
    seen_test_data = data_loader.load_test_dataset(args.seen_N,
                                                   args.seen_NP,
                                                   args.data_folder, obs_f,
                                                   args.seen_s,
                                                   args.seen_sp)
    if args.unseen_N > 0:
        # NOTE(review): loaded but not evaluated anywhere in this function.
        unseen_test_data = data_loader.load_test_dataset(
            args.unseen_N, args.unseen_NP, args.data_folder, obs_f,
            args.unseen_s, args.unseen_sp)

    # testing
    result_queue = Queue(1)
    print('testing...')

    obc, obs, paths, sgs, path_lengths, controls, costs = seen_test_data
    obc = obc.astype(np.float32)

    plan_times = []
    plan_res_all = []
    for i in range(len(paths)):
        # Expand every obstacle entry (x, y) into the four corners of a
        # square of side obs_width around it, flattened to 8 numbers.
        obs_i = [[
            ob[0] - obs_width / 2, ob[1] - obs_width / 2,
            ob[0] - obs_width / 2, ob[1] + obs_width / 2,
            ob[0] + obs_width / 2, ob[1] + obs_width / 2,
            ob[0] + obs_width / 2, ob[1] - obs_width / 2,
        ] for ob in obs[i]]

        for j in range(len(paths[i])):
            start_state = sgs[i][j][0]
            goal_inform_state = paths[i][j][-1]
            goal_state = sgs[i][j][1]
            print('environment: %d/%d, path: %d/%d' %
                  (i + 1, len(paths), j + 1, len(paths[i])))

            p = Process(target=plan_one_path,
                        args=(obs_i, obs[i], obc[i], start_state, goal_state,
                              goal_inform_state, 1000, result_queue))
            p.start()
            p.join()
            if p.exitcode == 0:
                res = result_queue.get()
            else:
                # The worker died (exception or native crash).  A blocking
                # get() would wait forever if it never reported, so take a
                # result only if one is already there and otherwise count
                # the path as failed.
                print('planner process exited with code %s.' % (p.exitcode,))
                try:
                    res = result_queue.get_nowait()
                except Empty:
                    res = -1

            if res == -1:
                plan_res_all.append(0)
            else:
                plan_times.append(res)
                plan_res_all.append(1)
            print('average accuracy up to now: %f' %
                  (np.array(plan_res_all).flatten().mean()))
            print('plan average time: %f' % (np.array(plan_times).mean()))
            print('plan time std: %f' % (np.array(plan_times).std()))
    # The flat per-path list is used for the final figure as well: a nested
    # per-environment list is not a regular array when environments have
    # different numbers of paths.
    print('plan accuracy: %f' % (np.array(plan_res_all).flatten().mean()))
    print('plan average time: %f' % (np.array(plan_times).mean()))
    print('plan time std: %f' % (np.array(plan_times).std()))
def main(args):
    """Run the DeepSMP planner on every path of the "seen" test set and report stats.

    Each start/goal pair is planned in its own child process
    (``plan_one_path``); the child reports either its planning time or ``-1``
    through a queue.  Running success rate and the mean / standard deviation
    of the successful planning times are printed after every path and once
    more at the end.

    Args:
        args: namespace providing ``model_path``, ``env_type``, ``data_folder``
            and the dataset selectors ``seen_N``, ``seen_NP``, ``seen_s``,
            ``seen_sp``, ``unseen_N``, ``unseen_NP``, ``unseen_s``,
            ``unseen_sp``.

    NOTE(review): ``obs_f`` is only assigned for ``env_type == 'acrobot_obs'``
    and ``obs_width`` only for the acrobot variants, so other environment
    types fail with a NameError further down.  ``seen_test_data`` is likewise
    only bound when ``args.seen_N > 0``.  The network file names are
    hard-coded to the acrobot models.
    """
    # `queue` is used as a local variable name below, so import only the
    # exception that a timed-out `get` raises.
    from queue import Empty as QueueEmpty

    # Environment names that share the acrobot settings.
    acrobot_envs = ('acrobot_obs', 'acrobot_obs_2', 'acrobot_obs_3',
                    'acrobot_obs_4', 'acrobot_obs_8')

    # set seed
    print(args.model_path)
    # torch_seed is still drawn although torch seeding is disabled, so the two
    # following draws consume the NumPy RNG exactly as before.
    torch_seed = np.random.randint(low=0, high=1000)
    np_seed = np.random.randint(low=0, high=1000)
    py_seed = np.random.randint(low=0, high=1000)
    #torch.manual_seed(torch_seed)
    np.random.seed(np_seed)
    random.seed(py_seed)
    # Build the models
    #if torch.cuda.is_available():
    #    torch.cuda.set_device(args.device)

    # setup evaluation function and load function
    if args.env_type == 'acrobot_obs':
        obs_file = None
        obc_file = None
        system = _sst_module.PSOPTAcrobot()
        cpp_propagator = _sst_module.SystemPropagator()
        dynamics = lambda x, u, t: cpp_propagator.propagate(system, x, u, t)

        obs_f = True
        bvp_solver = _sst_module.PSOPTBVPWrapper(system, 4, 1, 0)
        step_sz = 0.02
        num_steps = 21
        traj_opt = lambda x0, x1, step_sz, num_steps, x_init, u_init, t_init: bvp_solver.solve(x0, x1, 200, num_steps, step_sz*1, step_sz*(num_steps-1), x_init, u_init, t_init)
        goal_S0 = np.diag([1.,1.,0,0])
        #goal_S0 = np.identity(4)
        goal_rho0 = 1.0

    if args.env_type == 'pendulum':
        step_sz = 0.002
        num_steps = 20

    elif args.env_type == 'cartpole_obs':
        #system = standard_cpp_systems.RectangleObs(obs[i], 4.0, 'cartpole')
        step_sz = 0.002
        num_steps = 20
        goal_S0 = np.identity(4)
        goal_rho0 = 1.0
    elif args.env_type in acrobot_envs:
        #system = standard_cpp_systems.RectangleObs(obs[i], 6.0, 'acrobot')
        obs_width = 6.0
        step_sz = 0.02
        num_steps = 21
        goal_radius=10.0
        random_seed=0
        delta_near=1.0
        delta_drain=0.5


    # load previously trained model if start epoch > 0
    #model_path='kmpnet_epoch_%d_direction_0_step_%d.pkl' %(args.start_epoch, args.num_steps)
    mlp_path = os.path.join(os.getcwd()+'/c++/','acrobot_obs_MLP_epoch_5000.pt')
    encoder_path = os.path.join(os.getcwd()+'/c++/','acrobot_obs_encoder_epoch_5000.pt')
    cost_mlp_path = os.path.join(os.getcwd()+'/c++/','costnet_acrobot_obs_MLP_epoch_800_step_10.pt')
    cost_encoder_path = os.path.join(os.getcwd()+'/c++/','costnet_acrobot_obs_encoder_epoch_800_step_10.pt')

    print('mlp_path:')
    print(mlp_path)
    #####################################################
    def plan_one_path(obs_i, obs, obc, start_state, goal_state, goal_inform_state, cost_i, max_iteration, out_queue):
        """Plan one start/goal pair and put the outcome on ``out_queue``.

        Puts the wall-clock planning time (seconds) when the planner returns a
        non-empty state sequence, and ``-1`` otherwise.

        ``obs`` holds the raw obstacle entries handed to ``RectangleObs`` and
        ``obc`` is flattened and forwarded to the planner.  ``obs_i`` (corner
        lists) is only referenced by the disabled debug blocks, ``goal_state``
        is not used (the planner receives ``goal_inform_state`` twice), and
        ``cost_i`` is only printed.
        """
        if args.env_type in acrobot_envs:
            #system = standard_cpp_systems.RectangleObs(obs[i], 6.0, 'acrobot')
            psopt_system = _sst_module.PSOPTAcrobot()
            propagate_system = standard_cpp_systems.RectangleObs(obs, 6., 'acrobot')
            distance_computer = propagate_system.distance_computer()
            #distance_computer = _sst_module.euclidean_distance(np.array(propagate_system.is_circular_topology()))
            step_sz = 0.02
            num_steps = 21
            goal_radius=2.0
            delta_near=1.0
            delta_drain=0.5
            device = 3  # presumably a GPU index -- TODO confirm against DeepSMPWrapper
        #print('creating planner...')
        planner = vis_planners.DeepSMPWrapper(mlp_path, encoder_path, cost_mlp_path, cost_encoder_path, \
                                              200, num_steps, step_sz, propagate_system, device)
        #cost_threshold = cost_i * 1.1
        # Effectively disables the cost bound; the data cost is only printed below.
        cost_threshold = 100000000.
        # generate a path by using SST to plan for some maximal iterations
        time0 = time.time()
        res_x, res_u, res_t = planner.plan_tree_SMP_cost_gradient("sst", propagate_system, psopt_system, obc.flatten(), start_state, goal_inform_state, goal_inform_state, \
                                goal_radius, max_iteration, distance_computer, \
                                delta_near, delta_drain, cost_threshold, 15)
        plan_time = time.time() - time0

        # Disabled debugging aid (kept as a bare string literal, so it never
        # runs): scatter-plots the collision map and the propagated result.
        """
        # visualization
        plt.ion()
        fig = plt.figure()
        ax = fig.add_subplot(111)
        #ax.set_autoscale_on(True)
        ax.set_xlim(-np.pi, np.pi)
        ax.set_ylim(-np.pi, np.pi)
        hl, = ax.plot([], [], 'b')
        #hl_real, = ax.plot([], [], 'r')
        hl_for, = ax.plot([], [], 'g')
        hl_back, = ax.plot([], [], 'r')
        hl_for_mpnet, = ax.plot([], [], 'lightgreen')
        hl_back_mpnet, = ax.plot([], [], 'salmon')

        #print(obs)
        def update_line(h, ax, new_data):
            new_data = wrap_angle(new_data, propagate_system)
            h.set_data(np.append(h.get_xdata(), new_data[0]), np.append(h.get_ydata(), new_data[1]))
            #h.set_xdata(np.append(h.get_xdata(), new_data[0]))
            #h.set_ydata(np.append(h.get_ydata(), new_data[1]))

        def remove_last_k(h, ax, k):
            h.set_data(h.get_xdata()[:-k], h.get_ydata()[:-k])

        def draw_update_line(ax):
            #ax.relim()
            #ax.autoscale_view()
            fig.canvas.draw()
            fig.canvas.flush_events()
            #plt.show()

        def wrap_angle(x, system):
            circular = system.is_circular_topology()
            res = np.array(x)
            for i in range(len(x)):
                if circular[i]:
                    # use our previously saved version
                    res[i] = x[i] - np.floor(x[i] / (2*np.pi))*(2*np.pi)
                    if res[i] > np.pi:
                        res[i] = res[i] - 2*np.pi
            return res
        dtheta = 0.1
        feasible_points = []
        infeasible_points = []
        imin = 0
        imax = int(2*np.pi/dtheta)


        for i in range(imin, imax):
            for j in range(imin, imax):
                x = np.array([dtheta*i-np.pi, dtheta*j-np.pi, 0., 0.])
                if IsInCollision(x, obs_i):
                    infeasible_points.append(x)
                else:
                    feasible_points.append(x)
        feasible_points = np.array(feasible_points)
        infeasible_points = np.array(infeasible_points)
        print('feasible points')
        print(feasible_points)
        print('infeasible points')
        print(infeasible_points)
        ax.scatter(feasible_points[:,0], feasible_points[:,1], c='yellow')
        ax.scatter(infeasible_points[:,0], infeasible_points[:,1], c='pink')
        #for i in range(len(data)):
        #    update_line(hl, ax, data[i])
        draw_update_line(ax)
        #state_t = start_state

        if len(res_u):
            # propagate data
            p_start = res_x[0]
            detail_paths = [p_start]
            detail_controls = []
            detail_costs = []
            state = [p_start]
            control = []
            cost = []
            for k in range(len(res_u)):
                #state_i.append(len(detail_paths)-1)
                max_steps = int(res_t[k]/step_sz)
                accum_cost = 0.
                #print('p_start:')
                #print(p_start)
                #print('data:')
                #print(paths[i][j][k])
                # modify it because of small difference between data and actual propagation
                p_start = res_x[k]
                state[-1] = res_x[k]
                for step in range(1,max_steps+1):
                    p_start = dynamics(p_start, res_u[k], step_sz)
                    p_start = enforce_bounds(p_start)
                    detail_paths.append(p_start)
                    accum_cost += step_sz
                    if (step % 1 == 0) or (step == max_steps):
                        state.append(p_start)
                        #print('control')
                        #print(controls[i][j])
                        cost.append(accum_cost)
                        accum_cost = 0.
            #print('p_start:')
            #print(p_start)
            #print('data:')
            #print(paths[i][j][-1])
            state[-1] = res_x[-1]



            xs_to_plot = np.array(state)
            for i in range(len(xs_to_plot)):
                xs_to_plot[i] = wrap_angle(xs_to_plot[i], propagate_system)
                if IsInCollision(xs_to_plot[i], obs_i):
                    print('in collision')
            ax.scatter(xs_to_plot[:,0], xs_to_plot[:,1], c='green')
            # draw start and goal
            #ax.scatter(start_state[0], goal_state[0], marker='X')
            draw_update_line(ax)
            plt.waitforbuttonpress()
        """

        #im = planner.visualize_nodes(propagate_system)
        #sec = input('Let us wait for user input')
        #show_image_opencv(im, "planning_tree", wait=True)

        # validate if the path contains collision
        # (disabled as well: re-propagates the result and asserts that every
        # intermediate state is collision free)
        """
        if len(res_u):
            # propagate data
            p_start = res_x[0]
            detail_paths = [p_start]
            detail_controls = []
            detail_costs = []
            state = [p_start]
            control = []
            cost = []
            for k in range(len(res_u)):
                #state_i.append(len(detail_paths)-1)
                max_steps = int(res_t[k]/step_sz)
                accum_cost = 0.
                #print('p_start:')
                #print(p_start)
                #print('data:')
                #print(paths[i][j][k])
                # modify it because of small difference between data and actual propagation
                p_start = res_x[k]
                state[-1] = res_x[k]
                for step in range(1,max_steps+1):
                    p_start = dynamics(p_start, res_u[k], step_sz)
                    p_start = enforce_bounds(p_start)
                    detail_paths.append(p_start)
                    accum_cost += step_sz
                    if (step % 1 == 0) or (step == max_steps):
                        state.append(p_start)
                        #print('control')
                        #print(controls[i][j])
                        cost.append(accum_cost)
                        accum_cost = 0.
                        # check collision for the new state
                        assert not IsInCollision(p_start, obs_i)

            #print('p_start:')
            #print(p_start)
            #print('data:')
            #print(paths[i][j][-1])
            state[-1] = res_x[-1]
        # validation end
        """

        print('plan time: %fs' % (plan_time))
        if len(res_x) == 0:
            print('failed.')
            out_queue.put(-1)
        else:
            print('path succeeded.')
            print('cost: %f' % (np.sum(res_t)))
            print('cost_threshold: %f' % (cost_threshold))
            print('data cost: %f' % (cost_i))
            out_queue.put(plan_time)
    ####################################################################################

    def _collect_result(worker, result_queue):
        """Return the worker's queued result, or -1 if it exits without one."""
        while True:
            try:
                return result_queue.get(timeout=1.0)
            except QueueEmpty:
                if worker.is_alive():
                    continue
                # The worker is gone; look once more in case its item was
                # still in flight when the timeout expired.
                try:
                    return result_queue.get(timeout=1.0)
                except QueueEmpty:
                    return -1

    # load data
    print('loading...')
    if args.seen_N > 0:
        seen_test_data = data_loader.load_test_dataset(args.seen_N, args.seen_NP,
                                  args.data_folder, obs_f, args.seen_s, args.seen_sp)
    if args.unseen_N > 0:
        # Loaded for parity with the seen set; not evaluated below.
        unseen_test_data = data_loader.load_test_dataset(args.unseen_N, args.unseen_NP,
                                  args.data_folder, obs_f, args.unseen_s, args.unseen_sp)
    # test
    # testing

    queue = Queue(1)
    print('testing...')

    obc, obs, paths, sgs, path_lengths, controls, costs = seen_test_data
    obc = obc.astype(np.float32)
    #obc = torch.from_numpy(obc)
    #if torch.cuda.is_available():
    #    obc = obc.cuda()

    plan_res = []
    plan_times = []
    plan_res_all = []
    half_width = obs_width / 2
    for i in range(len(paths)):
        # Expand every obstacle entry (first two components) into the x/y
        # coordinates of the four corners of a square of side `obs_width`.
        obs_i = [
            [ob[0] - half_width, ob[1] - half_width,
             ob[0] - half_width, ob[1] + half_width,
             ob[0] + half_width, ob[1] + half_width,
             ob[0] + half_width, ob[1] - half_width]
            for ob in obs[i]
        ]
        plan_res_env = []
        #print(obs_i)
        for j in range(len(paths[i])):
            start_state = sgs[i][j][0]
            goal_inform_state = paths[i][j][-1]
            goal_state = sgs[i][j][1]
            cost_i = costs[i][j].sum()
            #cost_i = 100000000.
            print('environment: %d/%d, path: %d/%d' % (i+1, len(paths), j+1, len(paths[i])))
            p = Process(target=plan_one_path, args=(obs_i, obs[i], obc[i], start_state, goal_state, goal_inform_state, cost_i, 300000, queue))
            #plan_one_path(obs_i, obs[i], obc[i], start_state, goal_state, goal_inform_state, cost_i, 300000, queue)
            p.start()
            # Drain the queue BEFORE joining: joining a process that still has
            # an unread item on a multiprocessing queue can deadlock.  A worker
            # that dies without reporting is counted as a failed plan instead
            # of blocking the evaluation forever.
            res = _collect_result(p, queue)
            p.join()
            if res == -1:
                plan_res_env.append(0)
                plan_res_all.append(0)
            else:
                plan_res_env.append(1)
                plan_times.append(res)
                plan_res_all.append(1)
            print('average accuracy up to now: %f' % (np.array(plan_res_all).flatten().mean()))
            print('plan average time: %f' % (np.array(plan_times).mean()))
            print('plan time std: %f' % (np.array(plan_times).std()))
        plan_res.append(plan_res_env)
    print('plan accuracy: %f' % (np.array(plan_res).flatten().mean()))
    print('plan average time: %f' % (np.array(plan_times).mean()))
    print('plan time std: %f' % (np.array(plan_times).std()))
Example #16
0
                data_path, data_control, data_cost, dynamics, enforce_bounds,
                system, step_sz)
            correct_path_file = dir + 'path_%d' % (j) + ".pkl"
            correct_control_file = dir + 'control_%d' % (j) + ".pkl"
            correct_cost_file = dir + 'cost_%d' % (j) + ".pkl"

            # store the corrected file
            file = open(correct_path_file, 'wb')
            pickle.dump(np.array(correct_path), file)
            file = open(correct_control_file, 'wb')
            pickle.dump(np.array(correct_control), file)
            file = open(correct_cost_file, 'wb')
            pickle.dump(np.array(correct_cost), file)


# One shared propagator and one PSOPT system object per robot; the two
# dynamics callables below bind the propagator to the matching system.
cpp_propagator = _sst_module.SystemPropagator()
acrobot_system = _sst_module.PSOPTAcrobot()
cartpole_system = _sst_module.PSOPTCartPole()


def acrobot_dynamics(x, u, t):
    """Forward (x, u, t) to ``cpp_propagator.propagate`` using ``acrobot_system``."""
    return cpp_propagator.propagate(acrobot_system, x, u, t)


def cartpole_dynamics(x, u, t):
    """Forward (x, u, t) to ``cpp_propagator.propagate`` using ``cartpole_system``."""
    return cpp_propagator.propagate(cartpole_system, x, u, t)

# Use the following if you don't want to modify the original dataset but want to create new files instead.
#correct_dataset(N=10, NP=1000, data_folder='../data/acrobot_obs/', obs_f=True, direction=0, dynamics=acrobot_dynamics, enforce_bounds=None, system=None, step_sz=0.02)

#correct_dataset(N=10, NP=1000, data_folder='../data/cartpole_obs/', obs_f=True, direction=0, dynamics=cartpole_dynamics, enforce_bounds=None, system=None, step_sz=0.002)

# Use the following if you want to change the original dataset.
rename_dataset(N=10,
               NP=1000,
Example #17
0
def main(args):
    # set seed
    torch_seed = np.random.randint(low=0, high=1000)
    np_seed = np.random.randint(low=0, high=1000)
    py_seed = np.random.randint(low=0, high=1000)
    np.random.seed(np_seed)
    random.seed(py_seed)
    # Build the models

    # setup evaluation function and load function
    if args.env_type == 'pendulum':
        obs_file = None
        obc_file = None
        obs_f = False
        #system = standard_cpp_systems.PSOPTPendulum()
        #bvp_solver = _sst_module.PSOPTBVPWrapper(system, 2, 1, 0)
    elif args.env_type == 'cartpole_obs':
        obs_file = None
        obc_file = None
        obs_f = True
        obs_width = 4.0
        step_sz = 0.002
        psopt_system = _sst_module.PSOPTCartPole()
        cpp_propagator = _sst_module.SystemPropagator()

        #system = standard_cpp_systems.RectangleObs(obs, 4., 'cartpole')
        dynamics = lambda x, u, t: cpp_propagator.propagate(psopt_system, x, u, t)
        cpp_state_validator = lambda x, obs: cpp_propagator.cartpole_validate(x, obs, obs_width)
        #system = standard_cpp_systems.RectangleObs(obs_list, args.obs_width, 'cartpole')
        #bvp_solver = _sst_module.PSOPTBVPWrapper(system, 4, 1, 0)
    elif args.env_type == 'acrobot_obs':
        obs_file = None
        obc_file = None

        obs_f = True
        obs_width = 6.0

        #system = standard_cpp_systems.RectangleObs(obs_list, args.obs_width, 'acrobot')
        #bvp_solver = _sst_module.PSOPTBVPWrapper(system, 4, 1, 0)

    # load data
    print('loading...')
    if args.seen_N > 0:
        seen_test_data = data_loader.load_test_dataset(args.seen_N, args.seen_NP,
                                  args.data_folder, obs_f, args.seen_s, args.seen_sp)
    if args.unseen_N > 0:
        unseen_test_data = data_loader.load_test_dataset(args.unseen_N, args.unseen_NP,
                                  args.data_folder, obs_f, args.unseen_s, args.unseen_sp)
    # test
    # testing


    print('testing...')
    seen_test_suc_rate = 0.
    unseen_test_suc_rate = 0.
    
    # find path
    

    # randomly pick up a point in the data, and find similar data in the dataset
    # plot the next point
    obc, obs, paths, sgs, path_lengths, controls, costs = seen_test_data
    for envi in range(10):
        for pathi in range(20):
            obs_i = obs[envi]
            new_obs_i = []
            obs_i = obs[envi]
            plan_res_path = []
            plan_time_path = []
            plan_cost_path = []
            data_cost_path = []
                        
            
            for k in range(len(obs_i)):
                obs_pt = []
                obs_pt.append(obs_i[k][0]-obs_width/2)
                obs_pt.append(obs_i[k][1]-obs_width/2)
                obs_pt.append(obs_i[k][0]-obs_width/2)
                obs_pt.append(obs_i[k][1]+obs_width/2)
                obs_pt.append(obs_i[k][0]+obs_width/2)
                obs_pt.append(obs_i[k][1]+obs_width/2)
                obs_pt.append(obs_i[k][0]+obs_width/2)
                obs_pt.append(obs_i[k][1]-obs_width/2)
                new_obs_i.append(obs_pt)
            obs_i = new_obs_i

            """
            # visualization
            plt.ion()
            fig = plt.figure()
            ax = fig.add_subplot(111)
            #ax.set_autoscale_on(True)
            ax.set_xlim(-30, 30)
            ax.set_ylim(-np.pi, np.pi)
            hl, = ax.plot([], [], 'b')
            #hl_real, = ax.plot([], [], 'r')
            hl_for, = ax.plot([], [], 'g')
            hl_back, = ax.plot([], [], 'r')
            hl_for_mpnet, = ax.plot([], [], 'lightgreen')
            hl_back_mpnet, = ax.plot([], [], 'salmon')

            #print(obs)
            def update_line(h, ax, new_data):
                new_data = wrap_angle(new_data, propagate_system)
                h.set_data(np.append(h.get_xdata(), new_data[0]), np.append(h.get_ydata(), new_data[1]))
                #h.set_xdata(np.append(h.get_xdata(), new_data[0]))
                #h.set_ydata(np.append(h.get_ydata(), new_data[1]))

            def remove_last_k(h, ax, k):
                h.set_data(h.get_xdata()[:-k], h.get_ydata()[:-k])

            def draw_update_line(ax):
                #ax.relim()
                #ax.autoscale_view()
                fig.canvas.draw()
                fig.canvas.flush_events()
                #plt.show()

            def wrap_angle(x, system):
                circular = system.is_circular_topology()
                res = np.array(x)
                for i in range(len(x)):
                    if circular[i]:
                        # use our previously saved version
                        res[i] = x[i] - np.floor(x[i] / (2*np.pi))*(2*np.pi)
                        if res[i] > np.pi:
                            res[i] = res[i] - 2*np.pi
                return res
            dx = 1
            dtheta = 0.1
            feasible_points = []
            infeasible_points = []
            imin = 0
            imax = int(2*30./dx)
            jmin = 0
            jmax = int(2*np.pi/dtheta)
            for i in range(imin, imax):
                for j in range(jmin, jmax):
                    x = np.array([dx*i-30, 0., dtheta*j-np.pi, 0.])
                    if IsInCollision(x, obs_i):
                        infeasible_points.append(x)
                        print('state:', x)
                        print('python collison')
                        print("cpp collision result: ", cpp_state_validator(x, obs[envi]))
                        
                    else:
                        feasible_points.append(x)
                        print('state:', x)
                        print('python not in collison')
                        print("cpp collision result: ", cpp_state_validator(x, obs[envi]))
            feasible_points = np.array(feasible_points)
            infeasible_points = np.array(infeasible_points)
            print('feasible points')
            print(feasible_points)
            print('infeasible points')
            print(infeasible_points)
            ax.scatter(feasible_points[:,0], feasible_points[:,2], c='yellow')
            ax.scatter(infeasible_points[:,0], infeasible_points[:,2], c='pink')
            #for i in range(len(data)):
            #    update_line(hl, ax, data[i])
            draw_update_line(ax)
            #state_t = start_state
            """
            xs = paths[envi][pathi]
            us = controls[envi][pathi]
            ts = costs[envi][pathi]
            # propagate data
            p_start = xs[0]
            detail_paths = [p_start]
            detail_controls = []
            detail_costs = []
            state = [p_start]
            control = []
            cost = []
            for k in range(len(us)):
                #state_i.append(len(detail_paths)-1)
                max_steps = int(np.round(ts[k]/step_sz))
                accum_cost = 0.
                print('p_start:')
                print(p_start)
                print('data:')
                print(paths[envi][pathi][k])
                
                
                # comment this out to test corrected data versus previous data
                #p_start = xs[k]
                #state[-1] = xs[k]
                
                for step in range(1,max_steps+1):
                    p_start = dynamics(p_start, us[k], step_sz)
                    p_start = enforce_bounds(p_start)
                    detail_paths.append(p_start)
                    accum_cost += step_sz
                    if (step % 1 == 0) or (step == max_steps):
                        state.append(p_start)
                        #print('control')
                        #print(controls[i][j])
                        cost.append(accum_cost)
                        accum_cost = 0.
                        
                        print('state:', p_start)
                        print('python IsInCollison result: ', IsInCollision(p_start, obs_i))
                        print("cpp validate result: ", cpp_state_validator(p_start, obs[envi]))
                        assert not IsInCollision(p_start, obs_i)

            print('p_start:')
            print(p_start)
            print('data:')
            print(paths[envi][pathi][-1])
            #state[-1] = xs[-1]

            """