Пример #1
0
def main(args):
    """Restore the two KMPNet models (direction 0 and 1) and evaluate them.

    The function seeds the RNGs, selects the environment-specific helpers
    from ``args.env_type``, builds two ``KMPNet`` instances, optionally
    restores them from ``kmpnet_epoch_<start_epoch>_direction_<d>.pkl`` under
    ``args.model_path``, and runs ``eval_tasks_mpnet`` on the seen and/or
    unseen test sets.

    Args:
        args: parsed command-line namespace. Fields read here: ``model_path``,
            ``device``, ``env_type``, ``total_input_size``, ``AE_input_size``,
            ``mlp_input_size``, ``output_size``, ``start_epoch``, ``opt``,
            ``learning_rate``, ``data_folder``, ``world_size`` and the
            ``seen_*`` / ``unseen_*`` dataset selectors.

    Raises:
        ValueError: if ``args.env_type`` is not one of the supported names.
            (Previously this surfaced later as an unbound-local error.)
    """
    print(args.model_path)

    # Draw three seeds from the (still unseeded) NumPy RNG and apply them.
    torch_seed = np.random.randint(low=0, high=1000)
    np_seed = np.random.randint(low=0, high=1000)
    py_seed = np.random.randint(low=0, high=1000)
    torch.manual_seed(torch_seed)
    np.random.seed(np_seed)
    random.seed(py_seed)

    if torch.cuda.is_available():
        torch.cuda.set_device(args.device)

    # ------------------------------------------------------------------
    # Environment-specific helpers.
    # ------------------------------------------------------------------
    env_type = args.env_type
    acrobot_envs = ('acrobot_obs', 'acrobot_obs_2', 'acrobot_obs_8')
    if env_type == 'pendulum':
        env = pendulum
        # NOTE(review): `system` used to be unbound for this environment, so
        # the load_train_dataset call below raised. None is passed instead;
        # confirm the loader accepts it for pure-Python dynamics.
        system = None
        dynamics = pendulum.dynamics
        cae = cae_identity
        mlp = MLP
        obs_f = False
    elif env_type == 'cartpole_obs':
        env = cartpole
        # NOTE(review): same as pendulum -- `system` was unbound before.
        system = None
        dynamics = cartpole.dynamics
        cae = CAE_acrobot_voxel_2d
        mlp = mlp_acrobot.MLP
        obs_f = True
    elif env_type in acrobot_envs:
        env = acrobot_obs
        # Every acrobot variant propagates through the C++ system. The
        # 'acrobot_obs_2' variant previously forgot to create `system`, which
        # made both `dynamics` and the dataset loader fail.
        system = _sst_module.PSOPTAcrobot()
        cpp_propagator = _sst_module.SystemPropagator()

        def dynamics(x, u, t):
            return cpp_propagator.propagate(system, x, u, t)

        # The variants differ only in the encoder / MLP architecture.
        if env_type == 'acrobot_obs':
            cae = CAE_acrobot_voxel_2d
            mlp = mlp_acrobot.MLP
        elif env_type == 'acrobot_obs_2':
            cae = CAE_acrobot_voxel_2d_2
            mlp = mlp_acrobot.MLP2
        else:  # 'acrobot_obs_8'
            cae = CAE_acrobot_voxel_2d_3
            mlp = mlp_acrobot.MLP6
        obs_f = True
    else:
        raise ValueError('unsupported env_type: %r' % (env_type,))

    IsInCollision = env.IsInCollision
    normalize = env.normalize
    unnormalize = env.unnormalize
    jax_dynamics = env.jax_dynamics
    enforce_bounds = env.enforce_bounds

    # Jacobians of the dynamics w.r.t. its first and second positional
    # arguments (argnums 0 and 1).
    jac_A = jax.jacfwd(jax_dynamics, argnums=0)
    jac_B = jax.jacfwd(jax_dynamics, argnums=1)

    # ------------------------------------------------------------------
    # Networks.
    # ------------------------------------------------------------------
    mpNet0 = KMPNet(args.total_input_size, args.AE_input_size,
                    args.mlp_input_size, args.output_size, cae, mlp)
    mpNet1 = KMPNet(args.total_input_size, args.AE_input_size,
                    args.mlp_input_size, args.output_size, cae, mlp)

    def restore(net, direction):
        """Load weights/seeds/optimizer state for one direction's network."""
        checkpoint = os.path.join(
            args.model_path, 'kmpnet_epoch_%d_direction_%d.pkl' %
            (args.start_epoch, direction))
        resume = args.start_epoch > 0
        if resume:
            load_net_state(net, checkpoint)
            torch_seed, np_seed, py_seed = load_seed(checkpoint)
            # Re-apply the seeds stored alongside the checkpoint.
            torch.manual_seed(torch_seed)
            np.random.seed(np_seed)
            random.seed(py_seed)
        if torch.cuda.is_available():
            net.cuda()
            net.mlp.cuda()
            net.encoder.cuda()
            # The optimizer is only created on the CUDA path, as before.
            if args.opt == 'Adagrad':
                net.set_opt(torch.optim.Adagrad, lr=args.learning_rate)
            elif args.opt == 'Adam':
                net.set_opt(torch.optim.Adam, lr=args.learning_rate)
            elif args.opt == 'SGD':
                net.set_opt(torch.optim.SGD,
                            lr=args.learning_rate,
                            momentum=0.9)
        if resume:
            load_opt_state(net, checkpoint)

    restore(mpNet0, 0)
    restore(mpNet1, 1)

    # The returned waypoint data is not used below; the call is kept because
    # its side effects (if any) are not visible from here.
    _, waypoint_dataset, waypoint_targets, _, _, _, _, _ = data_loader.load_train_dataset(
        1, 2, args.data_folder, obs_f, 1, dynamics, enforce_bounds, system,
        0.02, 20)

    # ------------------------------------------------------------------
    # Test data.
    # ------------------------------------------------------------------
    print('loading...')
    if args.seen_N > 0:
        seen_test_data = data_loader.load_test_dataset(args.seen_N,
                                                       args.seen_NP,
                                                       args.data_folder, obs_f,
                                                       args.seen_s,
                                                       args.seen_sp)
    if args.unseen_N > 0:
        unseen_test_data = data_loader.load_test_dataset(
            args.unseen_N, args.unseen_NP, args.data_folder, obs_f,
            args.unseen_s, args.unseen_sp)

    # ------------------------------------------------------------------
    # Evaluation.
    # ------------------------------------------------------------------
    print('testing...')

    def normalize_func(x):
        return normalize(x, args.world_size)

    def unnormalize_func(x):
        return unnormalize(x, args.world_size)

    if args.seen_N > 0:
        eval_tasks_mpnet(
            mpNet0, mpNet1, args.env_type, seen_test_data, args.model_path,
            'seen', normalize_func, unnormalize_func, dynamics, jac_A, jac_B,
            enforce_bounds, IsInCollision)
    if args.unseen_N > 0:
        eval_tasks_mpnet(
            mpNet0, mpNet1, args.env_type, unseen_test_data, args.model_path,
            'unseen', normalize_func, unnormalize_func, dynamics, jac_A, jac_B,
            enforce_bounds, IsInCollision)
Пример #2
0
def main(args):
    """Visualise reference trajectories and KMPNet roll-outs for cartpole.

    For up to two environments and up to ten paths per environment of the
    "seen" test set, the function opens a two-panel figure and draws:

    * a grid of collision-free (yellow) / colliding (pink) states,
    * the reference trajectory obtained by re-propagating the stored controls
      (green), and, after a button press,
    * 50 successive network predictions starting from the path start
      (light green).

    Changes relative to the previous revision:

    * the drawn ``torch_seed`` is now actually applied;
    * unsupported environment names raise ``ValueError`` up front (the
      'pendulum' and 'acrobot_obs' branches never bound ``cae``/``mlp``/
      ``dynamics`` and therefore always crashed when the network was built);
    * the copy-pasted cartpole branches are merged;
    * the orphan figure and the dead helper closures (one of which referenced
      the undefined name ``propagate_system``) are removed;
    * the unused "unseen" dataset is no longer loaded;
    * loop bounds are clamped to the data that is actually available;
    * empty point sets are skipped instead of failing on ``[:, 0]``;
    * each figure is closed once the user has stepped past it.

    Args:
        args: parsed command-line namespace. Fields read here: ``env_type``,
            ``total_input_size``, ``AE_input_size``, ``mlp_input_size``,
            ``output_size``, ``model_dir``, ``loss``, ``multigoal``,
            ``learning_rate``, ``opt``, ``num_steps``, ``start_epoch``,
            ``direction``, ``data_folder``, ``world_size`` and the ``seen_*``
            dataset selectors.

    Raises:
        ValueError: if ``args.env_type`` is not a cartpole variant, or if
            ``args.seen_N`` is not positive (the visualisation needs the
            seen test set).
    """
    # Draw three seeds from the (still unseeded) NumPy RNG and apply them.
    torch_seed = np.random.randint(low=0, high=1000)
    np_seed = np.random.randint(low=0, high=1000)
    py_seed = np.random.randint(low=0, high=1000)
    torch.manual_seed(torch_seed)
    np.random.seed(np_seed)
    random.seed(py_seed)

    # ------------------------------------------------------------------
    # Environment-specific helpers (cartpole variants only).
    # ------------------------------------------------------------------
    env_type = args.env_type
    if env_type == 'cartpole_obs':
        mlp = mlp_cartpole.MLP
    elif env_type == 'cartpole_obs_2':
        mlp = mlp_cartpole.MLP2
    elif env_type == 'cartpole_obs_3':
        mlp = mlp_cartpole.MLP4
    elif env_type == 'cartpole_obs_4':
        mlp = mlp_cartpole.MLP3
    else:
        raise ValueError(
            'unsupported env_type for this visualisation: %r' % (env_type,))

    obs_f = True
    obs_width = 4.0
    step_sz = 0.002
    cae = CAE_cartpole_voxel_2d
    normalize = cart_pole_obs.normalize
    unnormalize = cart_pole_obs.unnormalize
    enforce_bounds = cart_pole_obs.enforce_bounds
    psopt_system = _sst_module.PSOPTCartPole()
    cpp_propagator = _sst_module.SystemPropagator()

    def dynamics(x, u, t):
        return cpp_propagator.propagate(psopt_system, x, u, t)

    mpnet = KMPNet(args.total_input_size, args.AE_input_size,
                   args.mlp_input_size, args.output_size, cae, mlp, None)

    # ------------------------------------------------------------------
    # Checkpoint location and restore.
    # ------------------------------------------------------------------
    if args.loss == 'mse':
        run_name = "_lr%f_%s_step_%d" % (args.learning_rate, args.opt,
                                         args.num_steps)
    else:
        run_name = "_lr%f_%s_loss_%s_step_%d" % (
            args.learning_rate, args.opt, args.loss, args.num_steps)
    if args.multigoal != 0:
        run_name += "_multigoal"
    model_dir = args.model_dir + args.env_type + run_name + "/"

    print(model_dir)
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)

    model_path = 'kmpnet_epoch_%d_direction_%d_step_%d.pkl' % (
        args.start_epoch, args.direction, args.num_steps)
    checkpoint = os.path.join(model_dir, model_path)
    if args.start_epoch > 0:
        load_net_state(mpnet, checkpoint)
        torch_seed, np_seed, py_seed = load_seed(checkpoint)
        # Re-apply the seeds stored alongside the checkpoint.
        torch.manual_seed(torch_seed)
        np.random.seed(np_seed)
        random.seed(py_seed)

    if torch.cuda.is_available():
        mpnet.cuda()
        mpnet.mlp.cuda()
        mpnet.encoder.cuda()
        # The optimizer is only created on the CUDA path, as before.
        if args.opt == 'Adagrad':
            mpnet.set_opt(torch.optim.Adagrad, lr=args.learning_rate)
        elif args.opt == 'Adam':
            mpnet.set_opt(torch.optim.Adam, lr=args.learning_rate)
        elif args.opt == 'SGD':
            mpnet.set_opt(torch.optim.SGD, lr=args.learning_rate, momentum=0.9)
        elif args.opt == 'ASGD':
            mpnet.set_opt(torch.optim.ASGD, lr=args.learning_rate)
    if args.start_epoch > 0:
        load_opt_state(mpnet, checkpoint)

    mpnet.eval()

    # ------------------------------------------------------------------
    # Test data.
    # ------------------------------------------------------------------
    print('loading...')
    if args.seen_N <= 0:
        raise ValueError(
            'seen_N must be positive: the visualisation reads the seen test set')
    seen_test_data = data_loader.load_test_dataset(args.seen_N,
                                                   args.seen_NP,
                                                   args.data_folder, obs_f,
                                                   args.seen_s,
                                                   args.seen_sp)

    print('testing...')

    # ------------------------------------------------------------------
    # Local helpers.
    # ------------------------------------------------------------------
    def redraw(figure):
        """Push pending drawing commands of *figure* to the screen."""
        figure.canvas.draw()
        figure.canvas.flush_events()

    def wrap_angle(x):
        """Return a copy of state *x* with circular coordinates in (-pi, pi]."""
        circular = psopt_system.is_circular_topology()
        res = np.array(x)
        for i in range(len(x)):
            if circular[i]:
                res[i] = x[i] - np.floor(x[i] / (2 * np.pi)) * (2 * np.pi)
                if res[i] > np.pi:
                    res[i] = res[i] - 2 * np.pi
        return res

    def wrap_states(states):
        """Stack *states* into an array and wrap every row with wrap_angle."""
        wrapped = np.array(states)
        for i in range(len(wrapped)):
            wrapped[i] = wrap_angle(wrapped[i])
        return wrapped

    def obstacle_corners(centers):
        """Expand each obstacle centre into the 8-value corner list."""
        half = obs_width / 2
        return [[c[0] - half, c[1] - half,
                 c[0] - half, c[1] + half,
                 c[0] + half, c[1] + half,
                 c[0] + half, c[1] - half] for c in centers]

    def classify_grid(corners):
        """Split a zero-velocity (position, angle) grid by IsInCollision."""
        dx = 1
        dtheta = 0.1
        feasible = []
        infeasible = []
        for i in range(int(2 * 30. / dx)):
            for j in range(int(2 * np.pi / dtheta)):
                x = np.array([dx * i - 30, 0., dtheta * j - np.pi, 0.])
                if IsInCollision(x, corners):
                    infeasible.append(x)
                else:
                    feasible.append(x)
        return np.array(feasible), np.array(infeasible)

    def scatter_states(axis, states, col_x, col_y, **kwargs):
        """Scatter two columns of *states*; an empty set draws nothing."""
        if len(states) > 0:
            axis.scatter(states[:, col_x], states[:, col_y], **kwargs)

    def propagate_reference(xs, us, ts):
        """Re-propagate the stored controls and return every visited state."""
        state = [xs[0]]
        for k in range(len(us)):
            max_steps = int(ts[k] / step_sz)
            # Restart each segment from the stored waypoint: the data and the
            # actual propagation differ slightly.
            current = xs[k]
            state[-1] = xs[k]
            for _ in range(max_steps):
                current = dynamics(current, us[k], step_sz)
                current = enforce_bounds(current)
                state.append(current)
        state[-1] = xs[-1]
        return state

    # ------------------------------------------------------------------
    # Visualisation loop.
    # ------------------------------------------------------------------
    obc, obs, paths, sgs, path_lengths, controls, costs = seen_test_data
    for envi in range(min(2, len(paths))):
        # The obstacle layout and the collision grid depend only on the
        # environment, so they are computed once per environment.
        obs_i = obstacle_corners(obs[envi])
        feasible_points, infeasible_points = classify_grid(obs_i)

        for pathi in range(min(10, len(paths[envi]))):
            plt.ion()
            fig = plt.figure()
            ax = fig.add_subplot(121)
            ax_vel = fig.add_subplot(122)
            ax.set_xlim(-30, 30)
            ax.set_ylim(-np.pi, np.pi)
            ax_vel.set_xlim(-40, 40)
            ax_vel.set_ylim(-2, 2)

            print('feasible points')
            print(feasible_points)
            print('infeasible points')
            print(infeasible_points)
            scatter_states(ax, feasible_points, 0, 2, c='yellow')
            scatter_states(ax, infeasible_points, 0, 2, c='pink')
            redraw(fig)

            xs = paths[envi][pathi]
            us = controls[envi][pathi]
            ts = costs[envi][pathi]

            # Reference trajectory from the dataset.
            xs_to_plot = wrap_states(propagate_reference(xs, us, ts))
            ax.scatter(xs_to_plot[:, 0], xs_to_plot[:, 2], c='green')
            redraw(fig)
            ax_vel.scatter(xs_to_plot[:, 1],
                           xs_to_plot[:, 3],
                           c='green',
                           s=0.1)
            redraw(fig)

            plt.waitforbuttonpress()

            # Network roll-out: feed each prediction back in as the next
            # start state, always targeting the final waypoint.
            mpnet_paths = []
            state = xs[0]
            for k in range(50):
                mpnet_paths.append(state)
                bi = np.concatenate([state, xs[-1]])
                bi = np.array([bi])
                bi = torch.from_numpy(bi).type(torch.FloatTensor)
                print(bi)
                bi = normalize(bi, args.world_size)
                bi = to_var(bi)
                if obc is None:
                    bobs = None
                else:
                    bobs = np.array([obc[envi]]).astype(np.float32)
                    print(bobs.shape)
                    bobs = torch.FloatTensor(bobs)
                    bobs = to_var(bobs)
                bt = mpnet(bi, bobs).cpu()
                bt = unnormalize(bt, args.world_size)
                bt = bt.detach().numpy()
                print(bt.shape)
                state = bt[0]

            print(mpnet_paths)
            xs_to_plot = np.array(mpnet_paths)
            print(len(xs_to_plot))
            xs_to_plot = wrap_states(xs_to_plot)
            ax.scatter(xs_to_plot[:, 0], xs_to_plot[:, 2], c='lightgreen')
            redraw(fig)
            ax_vel.scatter(xs_to_plot[:, 1], xs_to_plot[:, 3], c='lightgreen')
            redraw(fig)
            plt.waitforbuttonpress()

            # pyplot keeps every figure alive until it is closed explicitly.
            plt.close(fig)
Пример #3
0
def main(args):
    """Visualise KMPNet samples and BVP steering for the pendulum test paths.

    For every non-empty path of the "seen" test set, the function opens a
    figure, draws the stored path, marks start (red) and goal (green), and
    then, ``max_iter`` times, queries the network for a next state from
    (start, goal), plots the sample (blue), and draws the trajectory returned
    by ``plan_general.steerTo`` from the start to that sample.

    Changes relative to the previous revision:

    * the obstacle tensor fed to the network is prepared once per environment;
      it used to be re-wrapped with ``unsqueeze(0)`` on every inner iteration
      and so gained one extra dimension per iteration;
    * paths with a non-positive length are skipped (they previously fell
      through and used an unbound or stale ``end``);
    * the inner loop iterates over ``paths[i]`` rather than ``paths[0]``;
    * a discarded ``np.expand_dims`` call, unused locals and the unused
      "unseen" dataset load are removed;
    * each figure is closed once the user has stepped past it.

    Args:
        args: parsed command-line namespace. Fields read here: ``model_path``,
            ``device``, ``env_type``, ``total_input_size``, ``AE_input_size``,
            ``mlp_input_size``, ``output_size``, ``start_epoch``, ``opt``,
            ``learning_rate``, ``path_folder``, ``world_size`` and the
            ``seen_*`` dataset selectors.

    Raises:
        ValueError: if ``args.env_type`` is not ``'pendulum'`` or if
            ``args.seen_N`` is not positive.
    """
    print(args.model_path)

    # Draw three seeds from the (still unseeded) NumPy RNG and apply them.
    torch_seed = np.random.randint(low=0, high=1000)
    np_seed = np.random.randint(low=0, high=1000)
    py_seed = np.random.randint(low=0, high=1000)
    torch.manual_seed(torch_seed)
    np.random.seed(np_seed)
    random.seed(py_seed)

    if torch.cuda.is_available():
        torch.cuda.set_device(args.device)

    # ------------------------------------------------------------------
    # Environment-specific helpers (only the pendulum is wired up here).
    # ------------------------------------------------------------------
    if args.env_type != 'pendulum':
        raise ValueError(
            'unsupported env_type for this visualisation: %r' %
            (args.env_type,))
    normalize = pendulum.normalize
    unnormalize = pendulum.unnormalize
    obs_file = None
    obc_file = None
    cae = cae_identity
    mlp = MLP
    system = standard_cpp_systems.PSOPTPendulum()
    bvp_solver = _sst_module.PSOPTBVPWrapper(system, 2, 1, 0)
    max_iter = 100

    # ------------------------------------------------------------------
    # Network.
    # ------------------------------------------------------------------
    mpNet = KMPNet(args.total_input_size, args.AE_input_size,
                   args.mlp_input_size, args.output_size, cae, mlp)
    checkpoint = os.path.join(args.model_path,
                              'kmpnet_epoch_%d.pkl' % (args.start_epoch))
    if args.start_epoch > 0:
        load_net_state(mpNet, checkpoint)
        torch_seed, np_seed, py_seed = load_seed(checkpoint)
        # Re-apply the seeds stored alongside the checkpoint.
        torch.manual_seed(torch_seed)
        np.random.seed(np_seed)
        random.seed(py_seed)
    if torch.cuda.is_available():
        mpNet.cuda()
        mpNet.mlp.cuda()
        mpNet.encoder.cuda()
        # The optimizer is only created on the CUDA path, as before.
        if args.opt == 'Adagrad':
            mpNet.set_opt(torch.optim.Adagrad, lr=args.learning_rate)
        elif args.opt == 'Adam':
            mpNet.set_opt(torch.optim.Adam, lr=args.learning_rate)
        elif args.opt == 'SGD':
            mpNet.set_opt(torch.optim.SGD, lr=args.learning_rate, momentum=0.9)
    if args.start_epoch > 0:
        load_opt_state(mpNet, checkpoint)

    # ------------------------------------------------------------------
    # Test data.
    # ------------------------------------------------------------------
    print('loading...')
    if args.seen_N <= 0:
        raise ValueError(
            'seen_N must be positive: the visualisation reads the seen test set')
    seen_test_data = data_loader.load_test_dataset(
        N=args.seen_N,
        NP=args.seen_NP,
        s=args.seen_s,
        sp=args.seen_sp,
        p_folder=args.path_folder,
        obs_f=obs_file,
        obc_f=obc_file)

    print('testing...')

    obc, obs, paths, path_lengths = seen_test_data
    if obs is not None:
        obs = obs.astype(np.float32)
        obs = torch.from_numpy(obs)

    def normalize_func(x):
        return normalize(x, args.world_size)

    def unnormalize_func(x):
        return unnormalize(x, args.world_size)

    # ------------------------------------------------------------------
    # Visualisation loop.
    # ------------------------------------------------------------------
    for i in range(len(paths)):
        # Build the network's obstacle input once per environment.
        # NOTE(review): the old code passed the whole `obs` tensor (and kept
        # adding a leading dimension on every iteration) although it computed
        # `obs[i]` and never used it. The per-environment slice with a batch
        # dimension is assumed to be the intended input -- confirm against
        # KMPNet.forward.
        if obs is None:
            obs_in = None
        else:
            obs_in = to_var(obs[i].unsqueeze(0))

        for j in range(len(paths[i])):
            print("step: i=" + str(i) + " j=" + str(j))
            length = path_lengths[i][j]
            if length <= 0:
                # Nothing to draw: there is neither a start nor a goal state.
                continue

            plt.ion()
            fig = plt.figure()
            ax = fig.add_subplot(111)
            ax.set_autoscale_on(True)
            hl, = ax.plot([], [], 'black')
            hl_real, = ax.plot([], [], 'yellow')

            start = paths[i][j][0]
            end = paths[i][j][length - 1]

            # Draw the stored path.
            for k in range(length):
                update_line(hl, ax, fig, paths[i][j][k])
            print('created RRT')

            new_sample = start
            print(new_sample)
            ax.scatter(new_sample[0], new_sample[1], c='r')
            ax.scatter(end[0], end[1], c='g')

            for iteration in range(max_iter):
                clear_line(hl_real, ax, fig)
                # NOTE(review): a discarded `np.expand_dims(ip1, 0)` call used
                # to sit here; the network has always been fed the unbatched
                # vector, and that is kept.
                ip1 = np.concatenate([new_sample, end])
                ip1 = normalize_func(ip1)
                ip1 = torch.FloatTensor(ip1)
                ip1 = to_var(ip1)
                sample = mpNet(ip1, obs_in).squeeze(0)
                # Back to world coordinates.
                sample = sample.data.cpu().numpy()
                sample = unnormalize_func(sample)
                ax.scatter(sample[0], sample[1], c='b')
                plt.pause(0.01)

                _, steer_state, _, _ = plan_general.steerTo(
                    bvp_solver, start, sample, None, None, step_sz=0.02)
                for k in range(len(steer_state)):
                    update_line(hl_real, ax, fig, steer_state[k])

            plt.waitforbuttonpress()
            # pyplot keeps every figure alive until it is closed explicitly.
            plt.close(fig)
Пример #4
0
def _wrap_circular_delta(delta_x, circular, flip_threshold):
    """Wrap the circular components of ``delta_x`` into (-pi, pi], in place.

    For every index flagged in ``circular`` the displacement is reduced
    modulo 2*pi to the shorter rotation.  Then, with probability 1/2 and
    only when the wrapped magnitude is at least ``flip_threshold * pi``,
    the rotation is taken the other way around instead.

    Args:
        delta_x: 1-D numpy array of state displacements; modified in place.
        circular: per-dimension truthy flags marking circular coordinates.
        flip_threshold: fraction of pi from which a flip may happen.

    Returns:
        The same ``delta_x`` array.
    """
    two_pi = 2 * np.pi
    for i in range(len(delta_x)):
        if not circular[i]:
            continue
        # can be either clockwise or counterclockwise, take shorter one
        delta_x[i] = delta_x[i] - np.floor(delta_x[i] / two_pi) * two_pi
        if delta_x[i] > np.pi:
            delta_x[i] = delta_x[i] - two_pi
        # randomly pick either direction
        rand_d = np.random.randint(2)
        if rand_d < 1 and np.abs(delta_x[i]) >= np.pi * flip_threshold:
            # The flip is strictly either/or: two independent `if`s would
            # add 2*pi straight back after subtracting it, cancelling the
            # flip for every positive delta.
            if delta_x[i] > 0.:
                delta_x[i] = delta_x[i] - two_pi
            else:
                delta_x[i] = delta_x[i] + two_pi
    return delta_x


def _random_initial_guess(x_from, x_to, num_steps, step_sz):
    """Build a randomized initial guess for the trajectory optimizer.

    Args:
        x_from: 1-D numpy state the state guess starts at.
        x_to: 1-D numpy state the state guess ends at.
        num_steps: number of knot points in the guess.
        step_sz: duration assigned to each of the ``num_steps - 1`` segments.

    Returns:
        Tuple ``(x_init, u_init, t_init)``: a straight line from ``x_from``
        to ``x_to`` perturbed by zero-mean Gaussian noise (variance 0.02 per
        dimension, endpoints left exact), uniformly random controls of shape
        ``(num_steps, 1)``, and evenly spaced time stamps.
    """
    noise = np.random.multivariate_normal(mean=np.zeros(x_from.shape),
                                          cov=np.diag([0.02] * len(x_from)),
                                          size=num_steps)
    # Keep both endpoints of the guess exactly on x_from / x_to.
    noise[0] = 0.
    noise[-1] = 0.
    x_init = np.linspace(x_from, x_to, num_steps) + noise
    ## TODO: : change this to general case (control bounds are hard-coded)
    u_init = np.random.uniform(low=[-4.], high=[4], size=(num_steps, 1))
    t_init = np.linspace(0, (num_steps - 1) * step_sz, num_steps)
    return x_init, u_init, t_init


def _restore_planner_net(net, args, direction):
    """Load weights, seeds and optimizer state for one planner network.

    Reads ``kmpnet_epoch_<start_epoch>_direction_<direction>.pkl`` from
    ``args.model_path`` when ``args.start_epoch > 0``; otherwise only moves
    the network to the GPU (if one is available) and attaches the optimizer.
    """
    model_path = 'kmpnet_epoch_%d_direction_%d.pkl' % (args.start_epoch,
                                                       direction)
    checkpoint = os.path.join(args.model_path, model_path)
    # load previously trained model if start epoch > 0
    if args.start_epoch > 0:
        load_net_state(net, checkpoint)
        torch_seed, np_seed, py_seed = load_seed(checkpoint)
        # set seed after loading
        torch.manual_seed(torch_seed)
        np.random.seed(np_seed)
        random.seed(py_seed)
    if torch.cuda.is_available():
        net.cuda()
        net.mlp.cuda()
        net.encoder.cuda()
        # NOTE(review): the optimizer is only attached when CUDA is
        # available, yet load_opt_state below runs regardless -- confirm
        # that CPU-only runs with start_epoch > 0 are expected to work.
        if args.opt == 'Adagrad':
            net.set_opt(torch.optim.Adagrad, lr=args.learning_rate)
        elif args.opt == 'Adam':
            net.set_opt(torch.optim.Adam, lr=args.learning_rate)
        elif args.opt == 'SGD':
            net.set_opt(torch.optim.SGD, lr=args.learning_rate, momentum=0.9)
    if args.start_epoch > 0:
        load_opt_state(net, checkpoint)


def main(args):
    """Evaluate the direction-0 / direction-1 KMPNet planners on test data.

    Builds the environment-specific system, dynamics and BVP-based
    trajectory optimizer, restores both networks from ``args.model_path``,
    runs ``eval_tasks`` on the seen and/or unseen test sets and writes each
    success rate (feasible paths / valid paths) to
    ``<seen|unseen>_accuracy_epoch_<start_epoch>.txt`` in ``args.model_path``.

    Args:
        args: parsed command-line namespace.  Reads ``model_path``,
            ``device``, ``env_type``, the KMPNet size arguments,
            ``start_epoch``, ``opt``, ``learning_rate``, ``world_size``,
            ``data_folder`` and the ``seen_*`` / ``unseen_*`` dataset
            selectors.

    Raises:
        NotImplementedError: for ``pendulum`` / ``cartpole_obs``, which have
            no system or BVP solver configured.
        ValueError: for any other unrecognized ``args.env_type``.
    """
    # set seed
    print(args.model_path)
    torch_seed = np.random.randint(low=0, high=1000)
    np_seed = np.random.randint(low=0, high=1000)
    py_seed = np.random.randint(low=0, high=1000)
    torch.manual_seed(torch_seed)
    np.random.seed(np_seed)
    random.seed(py_seed)
    # Build the models
    if torch.cuda.is_available():
        torch.cuda.set_device(args.device)

    # (encoder, MLP, third argument of bvp_solver.solve) per acrobot variant.
    # Wrapped in lambdas so only the selected variant's names are looked up.
    # 'acrobot_obs_7' was previously a second, unreachable 'acrobot_obs_6'
    # branch; it is the only variant using MLP5.
    acrobot_variants = {
        'acrobot_obs':
        lambda: (CAE_acrobot_voxel_2d, mlp_acrobot.MLP, 50),
        'acrobot_obs_2':
        lambda: (CAE_acrobot_voxel_2d_2, mlp_acrobot.MLP2, 400),
        'acrobot_obs_3':
        lambda: (CAE_acrobot_voxel_2d_2, mlp_acrobot.MLP3, 400),
        'acrobot_obs_5':
        lambda: (CAE_acrobot_voxel_2d_3, mlp_acrobot.MLP, 400),
        'acrobot_obs_6':
        lambda: (CAE_acrobot_voxel_2d_3, mlp_acrobot.MLP4, 400),
        'acrobot_obs_7':
        lambda: (CAE_acrobot_voxel_2d_3, mlp_acrobot.MLP5, 400),
        'acrobot_obs_8':
        lambda: (CAE_acrobot_voxel_2d_3, mlp_acrobot.MLP6, 400),
    }

    # setup evaluation function and load function
    if args.env_type in ('pendulum', 'cartpole_obs'):
        # These environments never had a system or BVP solver set up here:
        #   system = standard_cpp_systems.PSOPTPendulum()
        #   bvp_solver = _sst_module.PSOPTBVPWrapper(system, 2, 1, 0)
        #   system = standard_cpp_systems.RectangleObs(obs_list, args.obs_width, 'cartpole')
        #   bvp_solver = _sst_module.PSOPTBVPWrapper(system, 4, 1, 0)
        # Fail early instead of hitting an unbound `system` further down.
        raise NotImplementedError(
            'env_type %r has no system / BVP solver configured for '
            'evaluation' % (args.env_type, ))
    elif args.env_type in acrobot_variants:
        cae, mlp, bvp_budget = acrobot_variants[args.env_type]()
        IsInCollision = acrobot_obs.IsInCollision
        #IsInCollision = lambda x, obs: False
        normalize = acrobot_obs.normalize
        unnormalize = acrobot_obs.unnormalize
        system = _sst_module.PSOPTAcrobot()
        cpp_propagator = _sst_module.SystemPropagator()
        dynamics = lambda x, u, t: cpp_propagator.propagate(system, x, u, t)
        xdot = acrobot_obs.dynamics
        jax_dynamics = acrobot_obs.jax_dynamics
        enforce_bounds = acrobot_obs.enforce_bounds
        obs_f = True
        bvp_solver = _sst_module.PSOPTBVPWrapper(system, 4, 1, 0)
        step_sz = 0.02
        num_steps = 21

        def traj_opt(x0, x1, step_sz, num_steps, x_init, u_init, t_init):
            # bvp_budget is 50 for 'acrobot_obs' and 400 otherwise;
            # presumably an iteration cap -- confirm against
            # PSOPTBVPWrapper.solve.
            return bvp_solver.solve(x0, x1, bvp_budget, num_steps,
                                    step_sz * 1, step_sz * (num_steps - 1),
                                    x_init, u_init, t_init)
    else:
        raise ValueError('unknown env_type: %r' % (args.env_type, ))

    # Both networks are constructed before either checkpoint is restored.
    mpNet0 = KMPNet(args.total_input_size, args.AE_input_size,
                    args.mlp_input_size, args.output_size, cae, mlp)
    mpNet1 = KMPNet(args.total_input_size, args.AE_input_size,
                    args.mlp_input_size, args.output_size, cae, mlp)
    _restore_planner_net(mpNet0, args, 0)
    _restore_planner_net(mpNet1, args, 1)

    # (un)normalize functions bound to the configured world size
    normalize_func = lambda x: normalize(x, args.world_size)
    unnormalize_func = lambda x: unnormalize(x, args.world_size)

    # define informer
    circular = system.is_circular_topology()

    def informer(env, x0, xG, direction):
        """Predict the next waypoint from x0 towards xG with the network.

        Returns ``(res, x_init, u_init, t_init)`` where ``res`` is a Node at
        the predicted state and the rest is the optimizer's initial guess:
        from ``x0`` to the prediction for direction 0, and from the
        prediction back to ``x0`` otherwise.
        """
        x0_x = torch.from_numpy(x0.x).type(torch.FloatTensor)
        xG_x = torch.from_numpy(xG.x).type(torch.FloatTensor)
        x0_x = normalize_func(x0_x)
        xG_x = normalize_func(xG_x)
        if torch.cuda.is_available():
            x0_x = x0_x.cuda()
            xG_x = xG_x.cuda()
        x = torch.cat([x0_x, xG_x], dim=0)
        mpNet = mpNet0 if direction == 0 else mpNet1
        next_state = mpNet(x.unsqueeze(0), env.unsqueeze(0)).cpu().data
        next_state = unnormalize_func(next_state).numpy()[0]
        delta_x = _wrap_circular_delta(next_state - x0.x, circular, 0.5)
        next_state = x0.x + delta_x
        res = Node(next_state)
        if direction == 0:
            x_init, u_init, t_init = _random_initial_guess(
                x0.x, next_state, num_steps, step_sz)
        else:
            # The noise used to be referenced here without ever being
            # drawn in this branch; it is now generated like in direction 0.
            x_init, u_init, t_init = _random_initial_guess(
                next_state, x0.x, num_steps, step_sz)
        return res, x_init, u_init, t_init

    def init_informer(env, x0, xG, direction):
        """Initial guess connecting x0 and xG directly (no network query).

        Returns ``(x_init, u_init, t_init)``.  Direction 0 runs from ``x0``
        along the wrapped displacement towards ``xG``; any other direction
        starts at ``xG`` and follows the wrapped displacement towards ``x0``.
        """
        if direction == 0:
            # NOTE(review): this branch flips at 0.9*pi while every other
            # call uses 0.5*pi -- kept as found, confirm it is intended.
            delta_x = _wrap_circular_delta(xG.x - x0.x, circular, 0.9)
            return _random_initial_guess(x0.x, x0.x + delta_x, num_steps,
                                         step_sz)
        delta_x = _wrap_circular_delta(x0.x - xG.x, circular, 0.5)
        return _random_initial_guess(xG.x, xG.x + delta_x, num_steps,
                                     step_sz)

    # load data
    print('loading...')
    splits = []
    if args.seen_N > 0:
        splits.append(('seen',
                       data_loader.load_test_dataset(args.seen_N,
                                                     args.seen_NP,
                                                     args.data_folder, obs_f,
                                                     args.seen_s,
                                                     args.seen_sp)))
    if args.unseen_N > 0:
        splits.append(('unseen',
                       data_loader.load_test_dataset(args.unseen_N,
                                                     args.unseen_NP,
                                                     args.data_folder, obs_f,
                                                     args.unseen_s,
                                                     args.unseen_sp)))

    # testing
    print('testing...')
    T = 1  # number of evaluation repetitions the success rate is averaged over
    suc_rate = {tag: 0. for tag, _ in splits}
    for _ in range(T):
        for tag, test_data in splits:
            time_file = os.path.join(
                args.model_path,
                'time_%s_epoch_%d_mlp.p' % (tag, args.start_epoch))
            fes_path_, valid_path_ = eval_tasks(
                mpNet0, mpNet1, test_data, args.model_path, time_file,
                IsInCollision, normalize_func, unnormalize_func, informer,
                init_informer, system, dynamics, xdot, jax_dynamics,
                enforce_bounds, traj_opt, step_sz, num_steps)
            # notice different environments are involved
            valid_path = valid_path_.flatten()
            fes_path = fes_path_.flatten()
            suc_rate[tag] += fes_path.sum() / valid_path.sum()

    # Persist the averaged success rate of every evaluated split.
    for tag, _ in splits:
        accuracy_file = os.path.join(
            args.model_path,
            '%s_accuracy_epoch_%d.txt' % (tag, args.start_epoch))
        with open(accuracy_file, 'w') as f:
            f.write(str(suc_rate[tag] / T))
Пример #5
0
def main(args):
    """Print cost-network predictions next to their targets, batch by batch.

    Selects the environment-specific normalization, system and dynamics from
    ``args.env_type``, builds a KMPNet, restores it from the ``cost_*`` model
    directory when ``args.start_epoch > 0``, loads the cost training set,
    shuffles it and prints the network output and the target for every
    batch of ``args.batch_size`` samples.
    """
    #global hl
    if torch.cuda.is_available():
        torch.cuda.set_device(args.device)

    # environment setting: defaults first, acrobot variants override them
    cae = cae_identity
    mlp = MLP
    cpp_propagator = _sst_module.SystemPropagator()

    # (MLP, encoder) per acrobot variant; lambdas keep the lookups lazy so
    # only the selected variant's names are resolved.
    acrobot_nets = {
        'acrobot_obs': lambda: (mlp_acrobot.MLP, CAE_acrobot_voxel_2d),
        'acrobot_obs_2': lambda: (mlp_acrobot.MLP2, CAE_acrobot_voxel_2d_2),
        'acrobot_obs_3': lambda: (mlp_acrobot.MLP3, CAE_acrobot_voxel_2d_2),
        'acrobot_obs_4': lambda: (mlp_acrobot.MLP3, CAE_acrobot_voxel_2d_3),
        'acrobot_obs_5': lambda: (mlp_acrobot.MLP, CAE_acrobot_voxel_2d_3),
        'acrobot_obs_6': lambda: (mlp_acrobot.MLP4, CAE_acrobot_voxel_2d_3),
        'acrobot_obs_7': lambda: (mlp_acrobot.MLP5, CAE_acrobot_voxel_2d_3),
        'acrobot_obs_8': lambda: (mlp_acrobot.MLP6, CAE_acrobot_voxel_2d_3),
    }

    env_name = args.env_type
    # `unnormalize` and the local `num_steps` are assigned per environment
    # but not read below; the data loader receives args.num_steps instead.
    if env_name == 'pendulum':
        normalize = pendulum.normalize
        unnormalize = pendulum.unnormalize
        system = standard_cpp_systems.PSOPTPendulum()
        dynamics = None
        enforce_bounds = None
        step_sz = 0.002
        num_steps = 20
    elif env_name in ('cartpole', 'cartpole_obs'):
        # Both cart-pole flavours share system and dynamics and differ only
        # in the module supplying the (un)normalization.
        scaling = cart_pole if env_name == 'cartpole' else cart_pole_obs
        normalize = scaling.normalize
        unnormalize = scaling.unnormalize
        system = _sst_module.CartPole()
        dynamics = cartpole.dynamics
        enforce_bounds = cartpole.enforce_bounds
        step_sz = 0.002
        num_steps = 20
    elif env_name in acrobot_nets:
        normalize = acrobot_obs.normalize
        unnormalize = acrobot_obs.unnormalize
        system = _sst_module.PSOPTAcrobot()
        mlp, cae = acrobot_nets[env_name]()
        #dynamics = acrobot_obs.dynamics
        dynamics = lambda x, u, t: cpp_propagator.propagate(system, x, u, t)
        enforce_bounds = acrobot_obs.enforce_bounds
        step_sz = 0.02
        num_steps = 20

    mpnet = KMPNet(args.total_input_size, args.AE_input_size,
                   args.mlp_input_size, args.output_size, cae, mlp)

    # load net
    # load previously trained model if start epoch > 0
    run_suffix = "_lr%f_%s_step_%d/" % (args.learning_rate, args.opt,
                                        args.num_steps)
    model_dir = args.model_dir + 'cost_' + args.env_type + run_suffix
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)
    model_path = 'cost_kmpnet_epoch_%d_direction_%d_step_%d.pkl' % (
        args.start_epoch, args.direction, args.num_steps)
    checkpoint = os.path.join(model_dir, model_path)
    resume = args.start_epoch > 0

    torch_seed, np_seed, py_seed = 0, 0, 0
    if resume:
        load_net_state(mpnet, checkpoint)
        torch_seed, np_seed, py_seed = load_seed(checkpoint)
        # set seed after loading
        torch.manual_seed(torch_seed)
        np.random.seed(np_seed)
        random.seed(py_seed)

    if torch.cuda.is_available():
        mpnet.cuda()
        mpnet.mlp.cuda()
        mpnet.encoder.cuda()
        # optimizer class and extra keyword arguments per option name;
        # unknown names leave the network without an optimizer
        optimizers = {
            'Adagrad': (torch.optim.Adagrad, {}),
            'Adam': (torch.optim.Adam, {}),
            'SGD': (torch.optim.SGD, {'momentum': 0.9}),
            'ASGD': (torch.optim.ASGD, {}),
        }
        if args.opt in optimizers:
            opt_cls, opt_kwargs = optimizers[args.opt]
            mpnet.set_opt(opt_cls, lr=args.learning_rate, **opt_kwargs)
    if resume:
        load_opt_state(mpnet, checkpoint)
    mpnet.eval()

    # load train and test data
    print('loading...')
    loaded = data_loader.load_train_dataset_cost(
        N=args.no_env,
        NP=args.no_motion_paths,
        data_folder=args.path_folder,
        obs_f=True,
        direction=args.direction,
        dynamics=dynamics,
        enforce_bounds=enforce_bounds,
        system=system,
        step_sz=step_sz,
        num_steps=args.num_steps)
    # the loader returns eight values; only the first four are used here
    obs, cost_dataset, cost_targets, env_indices, _, _, _, _ = loaded

    # randomize the dataset, keeping samples, targets and env ids aligned
    data = list(zip(cost_dataset, cost_targets, env_indices))
    random.shuffle(data)
    dataset, targets, env_indices = (np.array(list(column))
                                     for column in zip(*data))

    for begin in range(0, len(dataset), args.batch_size):
        # calculate the current batch window
        window = slice(begin, begin + args.batch_size)
        env_indices_i = env_indices[window]
        # record
        bi = dataset[window].astype(np.float32)
        print('bi shape:')
        print(bi.shape)
        bi = torch.FloatTensor(bi)
        bt = torch.FloatTensor(targets[window])
        bi = to_var(normalize(bi, args.world_size))
        bt = to_var(bt)
        if obs is None:
            bobs = None
        else:
            # pick the obstacle encoding of each sample's environment
            bobs = to_var(
                torch.FloatTensor(obs[env_indices_i].astype(np.float32)))
        print('cost network output: ')
        print(mpnet(bi, bobs).cpu().data)
        print('target: ')
        print(bt.cpu().data)
Пример #6
0
def main(args):
    """Export the encoder and MLP of a cost KMPNet as TorchScript files.

    Builds a ``KMPNet`` for ``args.env_type``, restores the checkpoint of
    ``args.start_epoch`` when that epoch is greater than zero, traces the
    encoder and the MLP on one example batch and saves both traces into the
    current working directory.  Finally both modules are scripted and run on
    the same batch, and their outputs are printed next to the targets.

    With ``args.debug`` set, the example batch is the whole shuffled cost
    dataset returned by ``data_loader.load_train_dataset_cost``; otherwise a
    random batch of size one is generated.

    Args:
        args: Parsed command-line namespace.  Reads ``device``, ``debug``,
            ``env_type``, ``total_input_size``, ``AE_input_size``,
            ``mlp_input_size``, ``output_size``, ``model_dir``,
            ``learning_rate``, ``opt``, ``num_steps``, ``start_epoch`` and
            ``direction``; in debug mode additionally ``no_env``,
            ``no_motion_paths``, ``path_folder`` and ``world_size``.

    Raises:
        ValueError: If no encoder/MLP architecture is configured for
            ``args.env_type`` (only ``'acrobot_obs'`` and ``'acrobot_obs_8'``
            are).  The original code failed here with an unbound-local error.
    """
    if torch.cuda.is_available():
        torch.cuda.set_device(args.device)

    if args.debug:
        # The planner bindings and the dataset loader are only needed when the
        # example batch is built from real data.
        from sparse_rrt import _sst_module
        from plan_utility import cart_pole, cart_pole_obs, pendulum, acrobot_obs
        from tools import data_loader

        cpp_propagator = _sst_module.SystemPropagator()

    # Stay None unless the selected environment configures an architecture.
    cae = None
    mlp = None
    env_type = args.env_type

    if env_type == 'pendulum':
        if args.debug:
            normalize = pendulum.normalize
            system = standard_cpp_systems.PSOPTPendulum()
            dynamics = None
            enforce_bounds = None
            step_sz = 0.002
    elif env_type == 'cartpole':
        if args.debug:
            normalize = cart_pole.normalize
            dynamics = cartpole.dynamics
            system = _sst_module.CartPole()
            enforce_bounds = cartpole.enforce_bounds
            step_sz = 0.002
    elif env_type == 'cartpole_obs':
        if args.debug:
            normalize = cart_pole_obs.normalize
            system = _sst_module.CartPole()
            dynamics = cartpole.dynamics
            enforce_bounds = cartpole.enforce_bounds
            step_sz = 0.002
    elif env_type in ('acrobot_obs', 'acrobot_obs_8'):
        if args.debug:
            normalize = acrobot_obs.normalize
            system = _sst_module.PSOPTAcrobot()
            # Propagate through the C++ system instead of acrobot_obs.dynamics.
            dynamics = lambda x, u, t: cpp_propagator.propagate(
                system, x, u, t)
            enforce_bounds = acrobot_obs.enforce_bounds
            step_sz = 0.02
        # The two acrobot variants differ only in the network architecture.
        if env_type == 'acrobot_obs':
            mlp = mlp_acrobot.MLP
            cae = CAE_acrobot_voxel_2d
        else:
            mlp = mlp_acrobot.MLP6
            cae = CAE_acrobot_voxel_2d_3

    if cae is None or mlp is None:
        raise ValueError(
            "env_type %r has no encoder/MLP architecture configured for "
            "export" % (env_type,))

    mpnet = KMPNet(args.total_input_size, args.AE_input_size,
                   args.mlp_input_size, args.output_size, cae, mlp)

    # Locate the checkpoint of the requested epoch.
    model_dir = args.model_dir
    model_dir = model_dir + 'cost_' + args.env_type + "_lr%f_%s_step_%d/" % (
        args.learning_rate, args.opt, args.num_steps)
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)
    model_path = 'cost_kmpnet_epoch_%d_direction_%d_step_%d.pkl' % (
        args.start_epoch, args.direction, args.num_steps)
    checkpoint = os.path.join(model_dir, model_path)

    if args.start_epoch > 0:
        load_net_state(mpnet, checkpoint)
        torch_seed, np_seed, py_seed = load_seed(checkpoint)
        # Re-seed every RNG with the values stored in the checkpoint.
        torch.manual_seed(torch_seed)
        np.random.seed(np_seed)
        random.seed(py_seed)

    if torch.cuda.is_available():
        mpnet.cuda()
        mpnet.mlp.cuda()
        mpnet.encoder.cuda()
        # Optimizer class and extra keyword arguments per --opt value; an
        # unknown value leaves the network's optimizer untouched.
        optimizers = {
            'Adagrad': (torch.optim.Adagrad, {}),
            'Adam': (torch.optim.Adam, {}),
            'SGD': (torch.optim.SGD, {'momentum': 0.9}),
            'ASGD': (torch.optim.ASGD, {}),
        }
        if args.opt in optimizers:
            opt_class, opt_kwargs = optimizers[args.opt]
            mpnet.set_opt(opt_class, lr=args.learning_rate, **opt_kwargs)
    if args.start_epoch > 0:
        load_opt_state(mpnet, checkpoint)

    # Build the example batch (bobs, bi) used for tracing, plus targets bt.
    print('loading...')
    if args.debug:
        obs, cost_dataset, cost_targets, env_indices, \
        _, _, _, _ = data_loader.load_train_dataset_cost(
            N=args.no_env, NP=args.no_motion_paths,
            data_folder=args.path_folder, obs_f=True,
            direction=args.direction,
            dynamics=dynamics, enforce_bounds=enforce_bounds,
            system=system, step_sz=step_sz, num_steps=args.num_steps)
        # Shuffle inputs, targets and environment indices together.
        data = list(zip(cost_dataset, cost_targets, env_indices))
        random.shuffle(data)
        dataset, targets, env_indices = (
            np.array(column) for column in zip(*data))

        bi = dataset.astype(np.float32)
        print('bi shape:')
        print(bi.shape)
        bi = torch.FloatTensor(bi)
        bt = torch.FloatTensor(targets)
        bi = to_var(normalize(bi, args.world_size))
        bt = to_var(bt)
        if obs is None:
            bobs = None
        else:
            bobs = to_var(
                torch.FloatTensor(obs[env_indices].astype(np.float32)))
    else:
        # Random single-sample batch; the draw order (obstacles, input,
        # target) is kept so a seeded run reproduces the same values.
        bobs = np.random.rand(1, 1, args.AE_input_size, args.AE_input_size)
        bobs = to_var(torch.from_numpy(bobs).type(torch.FloatTensor))
        bi = np.random.rand(1, args.total_input_size)
        bt = np.random.rand(1, args.output_size)
        bi = to_var(torch.from_numpy(bi).type(torch.FloatTensor))
        bt = to_var(torch.from_numpy(bt).type(torch.FloatTensor))

    # Evaluation mode: tracing should not record training-only behaviour.
    mpnet.eval()

    # Local names deliberately differ from the module-level ``MLP`` class so
    # that class is not shadowed inside this function.
    mlp_module = mpnet.mlp
    encoder = mpnet.encoder
    traced_encoder = torch.jit.trace(encoder, (bobs))
    encoder_output = encoder(bobs)
    # The MLP consumes the encoder output concatenated with the state input.
    mlp_input = torch.cat((encoder_output, bi), 1)
    traced_mlp = torch.jit.trace(mlp_module, (mlp_input))
    traced_encoder.save("costnet_%s_encoder_epoch_%d_step_%d.pt" %
                        (args.env_type, args.start_epoch, args.num_steps))
    traced_mlp.save("costnet_%s_MLP_epoch_%d_step_%d.pt" %
                    (args.env_type, args.start_epoch, args.num_steps))

    # Sanity check: script both modules (this does not reload the traces saved
    # above) and run them on the example batch.
    scripted_encoder = torch.jit.script(encoder)
    scripted_mlp = torch.jit.script(mlp_module)
    scripted_encoder_output = scripted_encoder(bobs)
    scripted_mlp_input = torch.cat((scripted_encoder_output, bi), 1)
    scripted_mlp_output = scripted_mlp(scripted_mlp_input)
    print('encoder output: ', scripted_encoder_output)
    print('MLP output: ', scripted_mlp_output)
    print('data: ', bt)
def _mse_loss(y1, y2):
    """Return the squared error between ``y1`` and ``y2`` averaged over dim 0."""
    return torch.mean((y1 - y2) ** 2, dim=0)


def _l1_smooth_loss(y1, y2):
    """Return the smooth-L1 error between ``y1`` and ``y2`` averaged over dim 0.

    Absolute errors below 1 are replaced by half their square; larger errors
    are kept unchanged.
    """
    abs_err = torch.abs(y1 - y2)
    smooth = torch.where(abs_err < 1, 0.5 * abs_err ** 2, abs_err)
    return torch.mean(smooth, dim=0)


def _wrap_angular_error(abs_err):
    """Replace absolute errors in (1, 2] by ``2 - abs_err``; keep the rest."""
    wraps = (abs_err > 1.0) * (abs_err <= 2.0)
    return torch.where(wraps, 2 * 1.0 - abs_err, abs_err)


def _mse_decoupled(y1, y2):
    """Return the per-column mean squared error of a four-column output.

    Column 2 is treated as an angular dimension: its absolute error is wrapped
    with ``_wrap_angular_error`` before being squared.
    """
    abs_err = [torch.abs(y1[:, k] - y2[:, k]) for k in range(4)]
    abs_err[2] = _wrap_angular_error(abs_err[2])
    return torch.stack([torch.mean(err ** 2) for err in abs_err])


def _l1_smooth_decoupled(y1, y2):
    """Return the per-column smooth-L1 error of a two-column output.

    Column 1 is treated as an angular dimension and wrapped first.  The
    original code notes that this layout is specific to cartpole.
    """
    err_0 = torch.abs(y1[:, 0] - y2[:, 0])
    err_1 = _wrap_angular_error(torch.abs(y1[:, 1] - y2[:, 1]))
    err_0 = torch.where(err_0 < 1, 0.5 * err_0 ** 2, err_0)
    err_1 = torch.where(err_1 < 1, 0.5 * err_1 ** 2, err_1)
    return torch.stack([torch.mean(err_0), torch.mean(err_1)])


def _select_loss_functions(loss_name):
    """Return ``(position_loss, velocity_loss)`` for the ``--loss`` value.

    Raises:
        ValueError: If ``loss_name`` is not a supported loss.
    """
    if loss_name == 'mse':
        return _mse_loss, _mse_loss
    if loss_name == 'l1_smooth':
        return _l1_smooth_loss, _l1_smooth_loss
    if loss_name == 'mse_decoupled':
        return _mse_decoupled, _mse_decoupled
    if loss_name == 'l1_smooth_decoupled':
        # Only the position network uses the angle-aware variant.
        return _l1_smooth_decoupled, _l1_smooth_loss
    raise ValueError("unsupported loss %r" % (loss_name,))


def _select_environment(env_type, cpp_propagator):
    """Return the per-environment configuration as a dict.

    Keys: ``normalize``, ``system``, ``dynamics``, ``enforce_bounds``,
    ``step_sz``, ``mlp``, ``cae``, ``pos_indices`` and ``vel_indices``.  The
    two index entries are numpy arrays, or ``None`` for environments that
    define no position/velocity split (``pendulum`` and ``cartpole``).

    Args:
        env_type: Value of ``--env_type``.
        cpp_propagator: Propagator whose ``propagate(system, x, u, t)`` is
            wrapped as the ``dynamics`` callable of the obstacle environments.

    Raises:
        ValueError: If ``env_type`` is unknown.
    """
    # The getters are lambdas so that only the classes of the selected variant
    # are looked up.
    cartpole_obs_variants = {
        # env_type: (MLP getter, position indices, velocity indices)
        'cartpole_obs': (lambda: mlp_cartpole.MLP, [0, 2], [1, 3]),
        'cartpole_obs_2': (lambda: mlp_cartpole.MLP2, [0, 2], [1, 3]),
        'cartpole_obs_3': (lambda: mlp_cartpole.MLP4, [0, 2], [1, 3]),
        'cartpole_obs_4_small': (lambda: mlp_cartpole.MLP3, [0, 2], [1, 3]),
        'cartpole_obs_4_big': (lambda: mlp_cartpole.MLP3, [0, 2], [1, 3]),
        'cartpole_obs_4_small_x_theta':
            (lambda: mlp_cartpole.MLP3, [0, 1], [2, 3]),
        'cartpole_obs_4_big_x_theta':
            (lambda: mlp_cartpole.MLP3, [0, 1], [2, 3]),
        'cartpole_obs_4_small_decouple_output':
            (lambda: mlp_cartpole.MLP3, [0, 2], [1, 3]),
        'cartpole_obs_4_big_decouple_output':
            (lambda: mlp_cartpole.MLP3, [0, 2], [1, 3]),
    }
    acrobot_obs_variants = {
        # env_type: getter returning (MLP class, encoder class)
        'acrobot_obs': lambda: (mlp_acrobot.MLP, CAE_acrobot_voxel_2d),
        'acrobot_obs_2': lambda: (mlp_acrobot.MLP2, CAE_acrobot_voxel_2d_2),
        'acrobot_obs_3': lambda: (mlp_acrobot.MLP3, CAE_acrobot_voxel_2d_2),
        'acrobot_obs_4': lambda: (mlp_acrobot.MLP3, CAE_acrobot_voxel_2d_3),
        'acrobot_obs_5': lambda: (mlp_acrobot.MLP, CAE_acrobot_voxel_2d_3),
        'acrobot_obs_6': lambda: (mlp_acrobot.MLP4, CAE_acrobot_voxel_2d_3),
        'acrobot_obs_7': lambda: (mlp_acrobot.MLP5, CAE_acrobot_voxel_2d_3),
        'acrobot_obs_8': lambda: (mlp_acrobot.MLP6, CAE_acrobot_voxel_2d_3),
    }

    # Defaults for the environments that do not pick their own architecture.
    env = {
        'cae': cae_identity,
        'mlp': MLP,
        'pos_indices': None,
        'vel_indices': None,
    }

    if env_type == 'pendulum':
        env.update(
            normalize=pendulum.normalize,
            system=standard_cpp_systems.PSOPTPendulum(),
            dynamics=None,
            enforce_bounds=None,
            step_sz=0.002,
        )
    elif env_type == 'cartpole':
        env.update(
            normalize=cart_pole.normalize,
            dynamics=cartpole.dynamics,
            system=_sst_module.CartPole(),
            enforce_bounds=cartpole.enforce_bounds,
            step_sz=0.002,
        )
    elif env_type in cartpole_obs_variants:
        get_mlp, pos, vel = cartpole_obs_variants[env_type]
        system = _sst_module.PSOPTCartPole()
        env.update(
            normalize=cart_pole_obs.normalize,
            system=system,
            mlp=get_mlp(),
            cae=CAE_cartpole_voxel_2d,
            # Integrate through the C++ system to obtain dense trajectories.
            dynamics=lambda x, u, t: cpp_propagator.propagate(system, x, u, t),
            enforce_bounds=cart_pole_obs.enforce_bounds,
            step_sz=0.002,
            pos_indices=np.array(pos),
            vel_indices=np.array(vel),
        )
    elif env_type in acrobot_obs_variants:
        mlp_class, cae_class = acrobot_obs_variants[env_type]()
        system = _sst_module.PSOPTAcrobot()
        env.update(
            normalize=acrobot_obs.normalize,
            system=system,
            mlp=mlp_class,
            cae=cae_class,
            dynamics=lambda x, u, t: cpp_propagator.propagate(system, x, u, t),
            enforce_bounds=acrobot_obs.enforce_bounds,
            step_sz=0.02,
            # acrobot_obs_7 / acrobot_obs_8 originally defined no split; they
            # get the one shared by every other acrobot variant.
            pos_indices=np.array([0, 1]),
            vel_indices=np.array([2, 3]),
        )
    else:
        raise ValueError("unsupported env_type %r" % (env_type,))
    return env


def _set_optimizer(net, opt_name, learning_rate):
    """Install the optimizer named ``opt_name`` on ``net`` via ``net.set_opt``.

    An unrecognised name leaves the network's optimizer untouched.
    """
    if opt_name == 'Adagrad':
        net.set_opt(torch.optim.Adagrad, lr=learning_rate)
    elif opt_name == 'Adam':
        net.set_opt(torch.optim.Adam, lr=learning_rate)
    elif opt_name == 'SGD':
        net.set_opt(torch.optim.SGD, lr=learning_rate, momentum=0.9)
    elif opt_name == 'ASGD':
        net.set_opt(torch.optim.ASGD, lr=learning_rate)


def main(args):
    """Train separate position and velocity KMPNet models.

    The waypoint dataset is shuffled, split column-wise into position and
    velocity parts (unless the environment name contains
    ``'decouple_output'``, in which case both networks see the full input and
    only the targets are split), and its last 5% is held out for validation.
    After every training batch one validation batch is evaluated.  Losses
    averaged over 100 batches are logged through ``SummaryWriter`` and both
    networks are checkpointed every 50 epochs.

    Args:
        args: Parsed command-line namespace.  Reads ``device``, ``env_type``,
            ``loss``, ``total_input_size``, ``AE_input_size``,
            ``mlp_input_size``, ``output_size``, ``model_dir``,
            ``learning_rate``, ``opt``, ``num_steps``, ``multigoal``,
            ``start_epoch``, ``num_epochs``, ``direction``, ``no_env``,
            ``no_motion_paths``, ``path_folder``, ``batch_size`` and
            ``world_size``.

    Raises:
        ValueError: If ``args.loss`` or ``args.env_type`` is unsupported, or
            the environment defines no position/velocity index split.
    """
    if torch.cuda.is_available():
        torch.cuda.set_device(args.device)

    # Check the cheap, side-effect-free option first so a typo fails before
    # any system, network or dataset is created.
    loss_f_p, loss_f_v = _select_loss_functions(args.loss)

    cpp_propagator = _sst_module.SystemPropagator()
    env = _select_environment(args.env_type, cpp_propagator)
    normalize = env['normalize']
    system = env['system']
    dynamics = env['dynamics']
    enforce_bounds = env['enforce_bounds']
    step_sz = env['step_sz']
    mlp = env['mlp']
    cae = env['cae']
    pos_indices = env['pos_indices']
    vel_indices = env['vel_indices']
    if pos_indices is None or vel_indices is None:
        raise ValueError(
            "env_type %r defines no position/velocity index split"
            % (args.env_type,))

    decouple_output = 'decouple_output' in args.env_type
    if decouple_output:
        print('mpnet using decoupled output')
        net_input_size = args.total_input_size
    else:
        net_input_size = args.total_input_size // 2
    mpnet_pnet = KMPNet(net_input_size, args.AE_input_size,
                        args.mlp_input_size, args.output_size // 2,
                        cae, mlp, loss_f_p)
    mpnet_vnet = KMPNet(net_input_size, args.AE_input_size,
                        args.mlp_input_size, args.output_size // 2,
                        cae, mlp, loss_f_v)

    # Checkpoint directory: <model_dir><env>_lr.._<opt>[_loss_<loss>]_step_..
    if args.loss == 'mse':
        run_tag = "_lr%f_%s_step_%d" % (
            args.learning_rate, args.opt, args.num_steps)
    else:
        run_tag = "_lr%f_%s_loss_%s_step_%d" % (
            args.learning_rate, args.opt, args.loss, args.num_steps)
    if args.multigoal != 0:
        run_tag += "_multigoal"
    model_dir = args.model_dir + args.env_type + run_tag + "/"
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)

    def checkpoint_paths(epoch):
        """Return the (pnet, vnet) checkpoint paths for ``epoch``."""
        suffix = 'epoch_%d_direction_%d_step_%d.pkl' % (
            epoch, args.direction, args.num_steps)
        return (os.path.join(model_dir, 'kmpnet_pnet_' + suffix),
                os.path.join(model_dir, 'kmpnet_vnet_' + suffix))

    resume_pnet_path, resume_vnet_path = checkpoint_paths(args.start_epoch)
    torch_seed, np_seed, py_seed = 0, 0, 0
    if args.start_epoch > 0:
        load_net_state(mpnet_pnet, resume_pnet_path)
        load_net_state(mpnet_vnet, resume_vnet_path)
        # Both networks were saved with the same seeds; read them once.
        torch_seed, np_seed, py_seed = load_seed(resume_pnet_path)
        torch.manual_seed(torch_seed)
        np.random.seed(np_seed)
        random.seed(py_seed)

    if torch.cuda.is_available():
        for net in (mpnet_pnet, mpnet_vnet):
            net.cuda()
            net.mlp.cuda()
            net.encoder.cuda()
        # NOTE(review): as in the original, optimizers are configured only
        # when CUDA is available - confirm this is intended for CPU runs.
        _set_optimizer(mpnet_pnet, args.opt, args.learning_rate)
        _set_optimizer(mpnet_vnet, args.opt, args.learning_rate)
        if args.start_epoch > 0:
            # Each network restores its own optimizer state (the original
            # referenced an undefined ``model_path`` for both).
            load_opt_state(mpnet_pnet, resume_pnet_path)
            load_opt_state(mpnet_vnet, resume_vnet_path)

    # load train and test data
    print('loading...')
    obs, waypoint_dataset, waypoint_targets, env_indices, \
    _, _, _, _ = data_loader.load_train_dataset(
        N=args.no_env, NP=args.no_motion_paths,
        data_folder=args.path_folder, obs_f=True,
        direction=args.direction,
        dynamics=dynamics, enforce_bounds=enforce_bounds,
        system=system, step_sz=step_sz,
        num_steps=args.num_steps, multigoal=args.multigoal)

    # Shuffle inputs, targets and environment indices together.
    data = list(zip(waypoint_dataset, waypoint_targets, env_indices))
    random.shuffle(data)
    dataset, targets, env_indices = zip(*data)
    dataset = np.array(list(dataset))
    targets = np.array(targets)
    env_indices = np.array(list(env_indices))

    # Each input row has two halves of total_input_size // 2 columns; the same
    # position/velocity columns are picked from both halves.
    half = args.total_input_size // 2
    p_columns = np.concatenate([pos_indices, pos_indices + half])
    v_columns = np.concatenate([vel_indices, vel_indices + half])
    print(p_columns)
    if decouple_output:
        # only decouple output
        print('only decouple output but not input')
        p_dataset = dataset
        v_dataset = dataset
    else:
        p_dataset = dataset[:, p_columns]
        v_dataset = dataset[:, v_columns]
    print(p_dataset.shape)
    print(v_dataset.shape)

    p_targets = list(targets[:, pos_indices])
    v_targets = list(targets[:, vel_indices])

    # Hold out the last 5% for validation.  An explicit split index is used
    # because x[-0:] / x[:-0] would give all data to validation and none to
    # training when the hold-out is empty.
    val_len = int(len(dataset) * 0.05)
    split = len(dataset) - val_len
    val_p_dataset, p_dataset = p_dataset[split:], p_dataset[:split]
    val_v_dataset, v_dataset = v_dataset[split:], v_dataset[:split]
    val_p_targets, p_targets = p_targets[split:], p_targets[:split]
    val_v_targets, v_targets = v_targets[split:], v_targets[:split]
    val_env_indices, env_indices = env_indices[split:], env_indices[:split]

    # Train the Models
    print('training...')
    writer_fname = 'pos_vel_%s_%f_%s_direction_%d_step_%d' % (
        args.env_type, args.learning_rate, args.opt, args.direction,
        args.num_steps)
    if args.loss != 'mse':
        writer_fname += '_loss_%s' % (args.loss,)
    if args.multigoal != 0:
        writer_fname += '_multigoal'
    writer = SummaryWriter('./runs/' + writer_fname)

    record_i = 0
    val_record_i = 0
    p_loss_avg_i = 0
    p_val_loss_avg_i = 0
    p_loss_avg = 0.
    p_val_loss_avg = 0.
    v_loss_avg_i = 0
    v_val_loss_avg_i = 0
    v_loss_avg = 0.
    v_val_loss_avg = 0.

    loss_steps = 100  # log the running average every 100 batches

    world_size = np.array(args.world_size)
    pos_world_size = list(world_size[pos_indices])
    vel_world_size = list(world_size[vel_indices])

    def make_batch(p_inputs, v_inputs, p_tgts, v_tgts, announce=False):
        """Convert one batch to normalized variables (p_bi, v_bi, p_bt, v_bt)."""
        p_bi = p_inputs.astype(np.float32)
        v_bi = v_inputs.astype(np.float32)
        print('p_bi shape:')
        print(p_bi.shape)
        print('v_bi shape:')
        print(v_bi.shape)
        p_bi = torch.FloatTensor(p_bi)
        v_bi = torch.FloatTensor(v_bi)
        p_bt = torch.FloatTensor(p_tgts)
        v_bt = torch.FloatTensor(v_tgts)
        if decouple_output:
            if announce:
                print('using normalizatino of decoupled output')
            # only decouple output but not input: inputs keep the full state
            p_bi = normalize(p_bi, args.world_size)
            v_bi = normalize(v_bi, args.world_size)
        else:
            p_bi = normalize(p_bi, pos_world_size)
            v_bi = normalize(v_bi, vel_world_size)
        p_bt = normalize(p_bt, pos_world_size)
        v_bt = normalize(v_bt, vel_world_size)
        return to_var(p_bi), to_var(v_bi), to_var(p_bt), to_var(v_bt)

    def obstacle_batch(indices):
        """Return the obstacle tensors for ``indices``, or None without obs."""
        if obs is None:
            return None
        return to_var(torch.FloatTensor(obs[indices].astype(np.float32)))

    for epoch in range(args.start_epoch + 1, args.num_epochs + 1):
        print('epoch' + str(epoch))
        val_i = 0
        for i in range(0, len(p_dataset), args.batch_size):
            print('epoch: %d, training... path: %d' % (epoch, i+1))
            batch = slice(i, i + args.batch_size)
            p_bi, v_bi, p_bt, v_bt = make_batch(
                p_dataset[batch], v_dataset[batch],
                p_targets[batch], v_targets[batch], announce=True)

            mpnet_pnet.zero_grad()
            mpnet_vnet.zero_grad()
            bobs = obstacle_batch(env_indices[batch])

            print('-------pnet-------')
            print('before training losses:')
            print(mpnet_pnet.loss(mpnet_pnet(p_bi, bobs), p_bt))
            mpnet_pnet.step(p_bi, bobs, p_bt)
            print('after training losses:')
            print(mpnet_pnet.loss(mpnet_pnet(p_bi, bobs), p_bt))
            p_loss = mpnet_pnet.loss(mpnet_pnet(p_bi, bobs), p_bt)
            p_loss_avg += p_loss.cpu().data
            p_loss_avg_i += 1

            print('-------vnet-------')
            print('before training losses:')
            print(mpnet_vnet.loss(mpnet_vnet(v_bi, bobs), v_bt))
            mpnet_vnet.step(v_bi, bobs, v_bt)
            print('after training losses:')
            print(mpnet_vnet.loss(mpnet_vnet(v_bi, bobs), v_bt))
            v_loss = mpnet_vnet.loss(mpnet_vnet(v_bi, bobs), v_bt)
            v_loss_avg += v_loss.cpu().data
            v_loss_avg_i += 1

            if p_loss_avg_i >= loss_steps:
                p_loss_avg = p_loss_avg / p_loss_avg_i
                writer.add_scalar('p_train_loss_0', p_loss_avg[0], record_i)
                writer.add_scalar('p_train_loss_1', p_loss_avg[1], record_i)

                v_loss_avg = v_loss_avg / v_loss_avg_i
                writer.add_scalar('v_train_loss_0', v_loss_avg[0], record_i)
                writer.add_scalar('v_train_loss_1', v_loss_avg[1], record_i)

                record_i += 1
                p_loss_avg = 0.
                p_loss_avg_i = 0
                v_loss_avg = 0.
                v_loss_avg_i = 0

            if val_len == 0:
                # Too little data for a hold-out set; nothing to validate.
                continue

            # validation: evaluate the next batch of the hold-out set
            val_batch = slice(val_i, val_i + args.batch_size)
            val_i = val_i + args.batch_size
            # ``>=`` so the cursor wraps before it points past the last
            # sample; ``>`` produced an empty batch whenever val_len was a
            # multiple of the batch size.
            if val_i >= val_len:
                val_i = 0
            p_bi, v_bi, p_bt, v_bt = make_batch(
                val_p_dataset[val_batch], val_v_dataset[val_batch],
                val_p_targets[val_batch], val_v_targets[val_batch])
            bobs = obstacle_batch(val_env_indices[val_batch])

            print('-------pnet loss--------')
            p_loss = mpnet_pnet.loss(mpnet_pnet(p_bi, bobs), p_bt)
            # The original format string had no placeholder, so the value was
            # never printed.
            print('validation loss: %s' % (p_loss.cpu().data,))
            p_val_loss_avg += p_loss.cpu().data
            p_val_loss_avg_i += 1

            print('-------vnet loss--------')
            v_loss = mpnet_vnet.loss(mpnet_vnet(v_bi, bobs), v_bt)
            print('validation loss: %s' % (v_loss.cpu().data,))
            v_val_loss_avg += v_loss.cpu().data
            v_val_loss_avg_i += 1

            if p_val_loss_avg_i >= loss_steps:
                p_val_loss_avg = p_val_loss_avg / p_val_loss_avg_i
                writer.add_scalar('p_val_loss_0', p_val_loss_avg[0], val_record_i)
                writer.add_scalar('p_val_loss_1', p_val_loss_avg[1], val_record_i)
                v_val_loss_avg = v_val_loss_avg / v_val_loss_avg_i
                writer.add_scalar('v_val_loss_0', v_val_loss_avg[0], val_record_i)
                writer.add_scalar('v_val_loss_1', v_val_loss_avg[1], val_record_i)

                val_record_i += 1
                p_val_loss_avg = 0.
                p_val_loss_avg_i = 0
                v_val_loss_avg = 0.
                v_val_loss_avg_i = 0

        # Save the models
        if epoch > 0 and epoch % 50 == 0:
            save_pnet_path, save_vnet_path = checkpoint_paths(epoch)
            save_state(mpnet_pnet, torch_seed, np_seed, py_seed, save_pnet_path)
            save_state(mpnet_vnet, torch_seed, np_seed, py_seed, save_vnet_path)

    writer.export_scalars_to_json("./all_scalars.json")
    writer.close()
Пример #8
0
def main(args):
    """Compare network-guided RRT with uniformly sampled SST on the seen test paths.

    For every stored test path with a positive recorded length, the first and
    last states of the path are used as start and goal.  An RRT planner is
    stepped with samples produced by the KMPNet model, then an SST planner is
    stepped with uniform random samples drawn inside the system's state
    bounds.  Running success ratios of both planners are printed.

    Args:
        args: parsed command-line namespace.  Fields read here: ``model_path``,
            ``device``, ``env_type``, ``total_input_size``, ``AE_input_size``,
            ``mlp_input_size``, ``output_size``, ``start_epoch``, ``opt``,
            ``learning_rate``, ``seen_N``/``seen_NP``/``seen_s``/``seen_sp``,
            ``unseen_N``/``unseen_NP``/``unseen_s``/``unseen_sp``,
            ``path_folder`` and ``world_size``.

    Raises:
        ValueError: if ``args.env_type`` is not ``'pendulum'``, the only
            environment this script configures.
    """
    # set seed
    print(args.model_path)
    torch_seed = np.random.randint(low=0, high=1000)
    np_seed = np.random.randint(low=0, high=1000)
    py_seed = np.random.randint(low=0, high=1000)
    torch.manual_seed(torch_seed)
    np.random.seed(np_seed)
    random.seed(py_seed)
    # Build the models
    if torch.cuda.is_available():
        torch.cuda.set_device(args.device)

    # setup evaluation function and load function
    if args.env_type == 'pendulum':
        IsInCollision = pendulum.IsInCollision
        normalize = pendulum.normalize
        unnormalize = pendulum.unnormalize
        obs_file = None
        obc_file = None
        cae = cae_identity
        mlp = MLP
        system = standard_cpp_systems.PSOPTPendulum()
        bvp_solver = _sst_module.PSOPTBVPWrapper(system, 2, 1, 0)
        max_iter = 200
        min_time_steps = 20
        max_time_steps = 200
        integration_step = 0.002
        goal_radius = 0.1
        random_seed = 0
        sst_delta_near = 0.05
        sst_delta_drain = 0.02
        vel_idx = [1]
    else:
        # Fail here with a clear message instead of an unbound-name error
        # further down when `system`, `cae`, ... are first used.
        raise ValueError('unsupported env_type: %s' % (args.env_type,))

    mpNet = KMPNet(args.total_input_size, args.AE_input_size,
                   args.mlp_input_size, args.output_size, cae, mlp)
    # load previously trained model if start epoch > 0
    model_path = 'kmpnet_epoch_%d.pkl' % (args.start_epoch)
    if args.start_epoch > 0:
        load_net_state(mpNet, os.path.join(args.model_path, model_path))
        torch_seed, np_seed, py_seed = load_seed(
            os.path.join(args.model_path, model_path))
        # set seed after loading
        torch.manual_seed(torch_seed)
        np.random.seed(np_seed)
        random.seed(py_seed)
    if torch.cuda.is_available():
        mpNet.cuda()
        mpNet.mlp.cuda()
        mpNet.encoder.cuda()
        if args.opt == 'Adagrad':
            mpNet.set_opt(torch.optim.Adagrad, lr=args.learning_rate)
        elif args.opt == 'Adam':
            mpNet.set_opt(torch.optim.Adam, lr=args.learning_rate)
        elif args.opt == 'SGD':
            mpNet.set_opt(torch.optim.SGD, lr=args.learning_rate, momentum=0.9)
    if args.start_epoch > 0:
        load_opt_state(mpNet, os.path.join(args.model_path, model_path))

    # load train and test data
    print('loading...')
    if args.seen_N > 0:
        seen_test_data = data_loader.load_test_dataset(
            N=args.seen_N,
            NP=args.seen_NP,
            s=args.seen_s,
            sp=args.seen_sp,
            p_folder=args.path_folder,
            obs_f=obs_file,
            obc_f=obc_file)
    if args.unseen_N > 0:
        unseen_test_data = data_loader.load_test_dataset(
            N=args.unseen_N,
            NP=args.unseen_NP,
            s=args.unseen_s,
            sp=args.unseen_sp,
            p_folder=args.path_folder,
            obs_f=obs_file,
            obc_f=obc_file)

    # testing (only the seen set is evaluated below)
    print('testing...')
    obc, obs, paths, path_lengths = seen_test_data
    if obs is not None:
        obs = obs.astype(np.float32)
        obs = torch.from_numpy(obs)
    normalize_func = lambda x: normalize(x, args.world_size)
    unnormalize_func = lambda x: unnormalize(x, args.world_size)

    # Per-dimension bounds used for the uniform SST samples.
    state_bounds = system.get_state_bounds()
    low = [bound[0] for bound in state_bounds]
    high = [bound[1] for bound in state_bounds]

    # Number of network (or uniform) samples between two goal-biased steps.
    N_sample = 10

    for i in range(len(paths)):
        suc_n = 0
        sst_suc_n = 0
        # Build the obstacle input of this environment once.  Re-wrapping the
        # full `obs` tensor for every sample stacked one extra leading axis
        # on it per iteration.
        if obs is None:
            obs_i = None
        else:
            obs_i = to_var(obs[i].unsqueeze(0))

        for j in range(len(paths[i])):
            print("step: i=" + str(i) + " j=" + str(j))
            if path_lengths[i][j] <= 0:
                # Invalid entry: there is no goal to plan towards.  Skipping
                # avoids reusing the previous path's goal/planner (or hitting
                # an undefined name on the very first path).
                print('accuracy: %f' % (float(suc_n) / (j + 1)))
                print('sst accuracy: %f' % (float(sst_suc_n) / (j + 1)))
                continue

            start = paths[i][j][0]
            end = paths[i][j][path_lengths[i][j] - 1]
            # Overwrite component 1 of both endpoints with zero (this writes
            # into `paths` as well when the entries are array views).
            start[1] = 0.
            end[1] = 0.

            # ---- RRT driven by samples from the network ----
            planner = RRT(state_bounds=system.get_state_bounds(),
                          control_bounds=system.get_control_bounds(),
                          distance=system.distance_computer(),
                          start_state=start,
                          goal_state=end,
                          goal_radius=goal_radius,
                          random_seed=0)
            print('created RRT')
            new_sample = start
            for iteration in range(max_iter // N_sample):
                for num_sample in range(N_sample):
                    ip1 = np.concatenate([new_sample, end])
                    # Keep the batch axis: the original call discarded the
                    # result of expand_dims, although the output is later
                    # reduced with squeeze(0).
                    ip1 = np.expand_dims(ip1, 0)
                    ip1 = normalize_func(ip1)
                    ip1 = to_var(torch.FloatTensor(ip1))
                    sample = mpNet(ip1, obs_i).squeeze(0)
                    # unnormalize to world size
                    sample = sample.data.cpu().numpy()
                    sample = unnormalize_func(sample)
                    print('sample:')
                    print(sample)
                    print('start:')
                    print(start)
                    print('goal:')
                    print(end)
                    print('accuracy: %f' % (float(suc_n) / (j + 1)))
                    print('sst accuracy: %f' % (float(sst_suc_n) / (j + 1)))
                    planner.step_with_sample(system, sample, min_time_steps,
                                             max_time_steps, integration_step)
                    im = planner.visualize_nodes(system)
                    show_image(im, 'nodes', wait=False)
                # After each group of network samples, step towards the goal;
                # the returned value seeds the next network query.
                new_sample = planner.step_with_sample(system, end,
                                                      min_time_steps,
                                                      max_time_steps,
                                                      integration_step)

                solution = planner.get_solution()
                if solution is not None:
                    print('solved.')
                    suc_n += 1
                    break

            # ---- SST driven by uniform random samples ----
            planner = SST(state_bounds=system.get_state_bounds(),
                          control_bounds=system.get_control_bounds(),
                          distance=system.distance_computer(),
                          start_state=start,
                          goal_state=end,
                          goal_radius=goal_radius,
                          random_seed=0,
                          sst_delta_near=sst_delta_near,
                          sst_delta_drain=sst_delta_drain)

            for iteration in range(max_iter // N_sample):
                for k in range(N_sample):
                    sample = np.random.uniform(low=low, high=high)
                    planner.step_with_sample(system, sample, min_time_steps,
                                             max_time_steps, integration_step)
                    im = planner.visualize_tree(system)
                    show_image(im, 'tree', wait=False)
                    print('accuracy: %f' % (float(suc_n) / (j + 1)))
                    print('sst accuracy: %f' % (float(sst_suc_n) / (j + 1)))

                planner.step_with_sample(system, end, min_time_steps,
                                         max_time_steps, integration_step)
                solution = planner.get_solution()
                if solution is not None:
                    print('solved.')
                    sst_suc_n += 1
                    break

            print('accuracy: %f' % (float(suc_n) / (j + 1)))
            print('sst accuracy: %f' % (float(sst_suc_n) / (j + 1)))
Пример #9
0
def main(args):
    """Train a KMPNet on the cost dataset and log train/validation losses.

    The environment named by ``args.env_type`` selects the normalization
    functions, the system object, the dynamics and the network classes.  The
    data returned by ``data_loader.load_train_dataset_cost`` is shuffled, the
    last 5% is held out for validation, and the network is trained batch by
    batch.  Averaged losses are written through a ``SummaryWriter`` and the
    model state is saved every 50 epochs.

    Args:
        args: parsed command-line namespace.  Fields read here: ``device``,
            ``env_type``, ``loss``, ``total_input_size``, ``AE_input_size``,
            ``mlp_input_size``, ``output_size``, ``model_dir``,
            ``learning_rate``, ``opt``, ``num_steps``, ``start_epoch``,
            ``direction``, ``no_env``, ``no_motion_paths``, ``path_folder``,
            ``world_size``, ``num_epochs`` and ``batch_size``.

    Raises:
        ValueError: if ``args.env_type`` or ``args.loss`` is not one of the
            values handled below.
    """
    #global hl

    if torch.cuda.is_available():
        torch.cuda.set_device(args.device)
    # environment setting
    multigoal = False
    cpp_propagator = _sst_module.SystemPropagator()
    # NOTE(review): the 'pendulum' and 'cartpole' branches never assign
    # `cae`/`mlp`, so building the network below fails for them -- confirm
    # which classes they are meant to use.
    if args.env_type == 'pendulum':
        normalize = pendulum.normalize
        unnormalize = pendulum.unnormalize
        system = standard_cpp_systems.PSOPTPendulum()
        dynamics = None
        enforce_bounds = None
        step_sz = 0.002
        num_steps = 20

    elif args.env_type == 'cartpole':
        normalize = cart_pole.normalize
        unnormalize = cart_pole.unnormalize
        dynamics = cartpole.dynamics
        system = _sst_module.CartPole()
        enforce_bounds = cartpole.enforce_bounds
        step_sz = 0.002
        num_steps = 20
    elif args.env_type == 'cartpole_obs':
        normalize = cart_pole_obs.normalize
        unnormalize = cart_pole_obs.unnormalize
        system = _sst_module.CartPole()
        dynamics = cartpole.dynamics
        enforce_bounds = cartpole.enforce_bounds
        step_sz = 0.002
        num_steps = 20
        cae = cae_identity
        mlp = MLP
    elif args.env_type == 'cartpole_obs_4':
        normalize = cart_pole_obs.normalize
        unnormalize = cart_pole_obs.unnormalize
        system = _sst_module.PSOPTCartPole()
        mlp = mlp_cartpole.MLP3_no_dropout
        cae = CAE_cartpole_voxel_2d
        dynamics = lambda x, u, t: cpp_propagator.propagate(system, x, u, t)
        multigoal = False
        enforce_bounds = cart_pole_obs.enforce_bounds
        step_sz = 0.002
        num_steps = 20
    elif args.env_type == 'cartpole_obs_4_multigoal':
        normalize = cart_pole_obs.normalize
        unnormalize = cart_pole_obs.unnormalize
        system = _sst_module.PSOPTCartPole()
        mlp = mlp_cartpole.MLP3_no_dropout
        cae = CAE_cartpole_voxel_2d
        dynamics = lambda x, u, t: cpp_propagator.propagate(system, x, u, t)
        #dynamics = None
        multigoal = True

        enforce_bounds = cart_pole_obs.enforce_bounds
        step_sz = 0.002
        num_steps = 20

    elif args.env_type == 'acrobot_obs':
        normalize = acrobot_obs.normalize
        unnormalize = acrobot_obs.unnormalize
        system = _sst_module.PSOPTAcrobot()
        mlp = mlp_acrobot.MLP
        cae = CAE_acrobot_voxel_2d
        #dynamics = acrobot_obs.dynamics
        dynamics = lambda x, u, t: cpp_propagator.propagate(system, x, u, t)
        enforce_bounds = acrobot_obs.enforce_bounds
        step_sz = 0.02
        num_steps = 20
    elif args.env_type == 'acrobot_obs_2':
        normalize = acrobot_obs.normalize
        unnormalize = acrobot_obs.unnormalize
        system = _sst_module.PSOPTAcrobot()
        mlp = mlp_acrobot.MLP2
        cae = CAE_acrobot_voxel_2d_2
        #dynamics = acrobot_obs.dynamics
        dynamics = lambda x, u, t: cpp_propagator.propagate(system, x, u, t)
        enforce_bounds = acrobot_obs.enforce_bounds
        step_sz = 0.02
        num_steps = 20
    elif args.env_type == 'acrobot_obs_3':
        normalize = acrobot_obs.normalize
        unnormalize = acrobot_obs.unnormalize
        system = _sst_module.PSOPTAcrobot()
        mlp = mlp_acrobot.MLP3
        cae = CAE_acrobot_voxel_2d_2
        #dynamics = acrobot_obs.dynamics
        dynamics = lambda x, u, t: cpp_propagator.propagate(system, x, u, t)
        enforce_bounds = acrobot_obs.enforce_bounds
        step_sz = 0.02
        num_steps = 20
    elif args.env_type == 'acrobot_obs_4':
        normalize = acrobot_obs.normalize
        unnormalize = acrobot_obs.unnormalize
        system = _sst_module.PSOPTAcrobot()
        mlp = mlp_acrobot.MLP3
        cae = CAE_acrobot_voxel_2d_3
        #dynamics = acrobot_obs.dynamics
        dynamics = lambda x, u, t: cpp_propagator.propagate(system, x, u, t)
        enforce_bounds = acrobot_obs.enforce_bounds
        step_sz = 0.02
        num_steps = 20
    elif args.env_type == 'acrobot_obs_5':
        normalize = acrobot_obs.normalize
        unnormalize = acrobot_obs.unnormalize
        system = _sst_module.PSOPTAcrobot()
        mlp = mlp_acrobot.MLP
        cae = CAE_acrobot_voxel_2d_3
        #dynamics = acrobot_obs.dynamics
        dynamics = lambda x, u, t: cpp_propagator.propagate(system, x, u, t)
        enforce_bounds = acrobot_obs.enforce_bounds
        step_sz = 0.02
        num_steps = 20
    elif args.env_type == 'acrobot_obs_6':
        normalize = acrobot_obs.normalize
        unnormalize = acrobot_obs.unnormalize
        system = _sst_module.PSOPTAcrobot()
        mlp = mlp_acrobot.MLP4
        cae = CAE_acrobot_voxel_2d_3
        #dynamics = acrobot_obs.dynamics
        dynamics = lambda x, u, t: cpp_propagator.propagate(system, x, u, t)
        enforce_bounds = acrobot_obs.enforce_bounds
        step_sz = 0.02
        num_steps = 20
    elif args.env_type == 'acrobot_obs_7':
        normalize = acrobot_obs.normalize
        unnormalize = acrobot_obs.unnormalize
        system = _sst_module.PSOPTAcrobot()
        mlp = mlp_acrobot.MLP5
        cae = CAE_acrobot_voxel_2d_3
        #dynamics = acrobot_obs.dynamics
        dynamics = lambda x, u, t: cpp_propagator.propagate(system, x, u, t)
        enforce_bounds = acrobot_obs.enforce_bounds
        step_sz = 0.02
        num_steps = 20
    elif args.env_type == 'acrobot_obs_8':
        normalize = acrobot_obs.normalize
        unnormalize = acrobot_obs.unnormalize
        system = _sst_module.PSOPTAcrobot()
        mlp = mlp_acrobot.MLP6
        cae = CAE_acrobot_voxel_2d_3
        #dynamics = acrobot_obs.dynamics
        dynamics = lambda x, u, t: cpp_propagator.propagate(system, x, u, t)
        enforce_bounds = acrobot_obs.enforce_bounds
        step_sz = 0.02
        num_steps = 20
    else:
        # Fail early with a clear message instead of an unbound-name error
        # when `normalize`, `system`, ... are first used.
        raise ValueError('unsupported env_type: %s' % (args.env_type,))

    # set loss for mpnet
    if args.loss == 'mse':
        #mpnet.loss_f = nn.MSELoss()
        def mse_loss(y1, y2):
            """Squared error averaged over dim 0; one value per output column."""
            l = (y1 - y2)**2
            # mean along the batch dimension, so the result keeps the
            # remaining (per-output) dimensions
            l = torch.mean(l, dim=0)
            return l

        loss_f = mse_loss

    elif args.loss == 'l1_smooth':
        #mpnet.loss_f = nn.SmoothL1Loss()
        def l1_smooth_loss(y1, y2):
            """Smooth-L1 error averaged over dim 0; one value per output column."""
            l1 = torch.abs(y1 - y2)
            cond = l1 < 1
            # quadratic below 1, linear above
            l = torch.where(cond, 0.5 * l1**2, l1)
            l = torch.mean(l, dim=0)
            # The original fell off the end here and returned None, which
            # made this loss unusable.
            return l

        loss_f = l1_smooth_loss

    elif args.loss == 'mse_decoupled':

        def mse_decoupled(y1, y2):
            """Per-column mean squared error over columns 0..3, with column 2 wrapped."""
            # for angle terms, wrap it to -pi~pi
            l_0 = torch.abs(y1[:, 0] - y2[:, 0])**2
            l_1 = torch.abs(y1[:, 1] - y2[:, 1])**2
            l_2 = torch.abs(y1[:, 2] - y2[:, 2])  # angular dimension
            l_3 = torch.abs(y1[:, 3] - y2[:, 3])**2

            # np.pi after normalization is 1.0
            cond = (l_2 > 1.0) * (l_2 <= 2.0)
            l_2 = torch.where(cond, 2.0 - l_2, l_2)
            l_2 = l_2**2
            l_0 = torch.mean(l_0)
            l_1 = torch.mean(l_1)
            l_2 = torch.mean(l_2)
            l_3 = torch.mean(l_3)
            return torch.stack([l_0, l_1, l_2, l_3])

        loss_f = mse_decoupled

    else:
        raise ValueError('unsupported loss: %s' % (args.loss,))

    mpnet = KMPNet(args.total_input_size, args.AE_input_size,
                   args.mlp_input_size, args.output_size, cae, mlp, loss_f)
    # load net
    # load previously trained model if start epoch > 0
    model_dir = args.model_dir
    model_dir = model_dir + 'cost_' + args.env_type + "_lr%f_%s_step_%d/" % (
        args.learning_rate, args.opt, args.num_steps)
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)
    model_path = 'cost_kmpnet_epoch_%d_direction_%d_step_%d.pkl' % (
        args.start_epoch, args.direction, args.num_steps)
    torch_seed, np_seed, py_seed = 0, 0, 0
    if args.start_epoch > 0:
        #load_net_state(mpnet, os.path.join(args.model_path, model_path))
        load_net_state(mpnet, os.path.join(model_dir, model_path))
        #torch_seed, np_seed, py_seed = load_seed(os.path.join(args.model_path, model_path))
        torch_seed, np_seed, py_seed = load_seed(
            os.path.join(model_dir, model_path))
        # set seed after loading
        torch.manual_seed(torch_seed)
        np.random.seed(np_seed)
        random.seed(py_seed)

    if torch.cuda.is_available():
        mpnet.cuda()
        mpnet.mlp.cuda()
        mpnet.encoder.cuda()
        if args.opt == 'Adagrad':
            mpnet.set_opt(torch.optim.Adagrad, lr=args.learning_rate)
        elif args.opt == 'Adam':
            mpnet.set_opt(torch.optim.Adam, lr=args.learning_rate)
        elif args.opt == 'SGD':
            mpnet.set_opt(torch.optim.SGD, lr=args.learning_rate, momentum=0.9)
        elif args.opt == 'ASGD':
            mpnet.set_opt(torch.optim.ASGD, lr=args.learning_rate)
    if args.start_epoch > 0:
        #load_opt_state(mpnet, os.path.join(args.model_path, model_path))
        load_opt_state(mpnet, os.path.join(model_dir, model_path))

    # load train and test data
    print('loading...')
    obs, cost_dataset, cost_targets, env_indices, \
    _, _, _, _ = data_loader.load_train_dataset_cost(N=args.no_env, NP=args.no_motion_paths,
                                                data_folder=args.path_folder, obs_f=True,
                                                direction=args.direction,
                                                dynamics=dynamics, enforce_bounds=enforce_bounds,
                                                system=system, step_sz=step_sz, num_steps=args.num_steps,
                                                multigoal=multigoal)
    # randomize the dataset before training
    data = list(zip(cost_dataset, cost_targets, env_indices))
    random.shuffle(data)
    dataset, targets, env_indices = list(zip(*data))
    dataset = np.array(list(dataset))
    targets = np.array(list(targets))
    env_indices = np.array(list(env_indices))

    # use 5% as validation dataset
    val_len = int(len(dataset) * 0.05)
    # Split on an explicit index: with val_len == 0 the slices `[-0:]` and
    # `[:-0]` would put every sample in validation and none in training.
    split = len(dataset) - val_len
    val_dataset = dataset[split:]
    val_targets = targets[split:]
    val_env_indices = env_indices[split:]

    dataset = dataset[:split]
    targets = targets[:split]
    env_indices = env_indices[:split]

    def _make_batch(inputs, batch_targets, batch_env_indices):
        """Convert one slice of the data into (inputs, targets, obstacles)."""
        bi = inputs.astype(np.float32)
        print('bi shape:')
        print(bi.shape)
        bi = torch.FloatTensor(bi)
        bt = torch.FloatTensor(batch_targets)
        bi = normalize(bi, args.world_size)
        bi = to_var(bi)
        bt = to_var(bt)
        if obs is None:
            bobs = None
        else:
            # pick the obstacle entry of each sample's environment
            bobs = obs[batch_env_indices].astype(np.float32)
            bobs = torch.FloatTensor(bobs)
            bobs = to_var(bobs)
        return bi, bt, bobs

    # Train the Models
    print('training...')
    writer_fname = 'cost_%s_%f_%s_direction_%d_step_%d' % (
        args.env_type, args.learning_rate, args.opt, args.direction,
        args.num_steps)
    writer = SummaryWriter('./runs/' + writer_fname)
    record_i = 0
    val_record_i = 0
    loss_avg_i = 0
    val_loss_avg_i = 0
    loss_avg = 0.
    val_loss_avg = 0.
    loss_steps = 100  # record every 100 loss
    for epoch in range(args.start_epoch + 1, args.num_epochs + 1):
        print('epoch' + str(epoch))
        val_i = 0
        for i in range(0, len(dataset), args.batch_size):
            print('epoch: %d, training... path: %d' % (epoch, i + 1))
            bi, bt, bobs = _make_batch(dataset[i:i + args.batch_size],
                                       targets[i:i + args.batch_size],
                                       env_indices[i:i + args.batch_size])
            mpnet.zero_grad()
            print('before training losses:')
            print(mpnet.loss(mpnet(bi, bobs), bt))
            mpnet.step(bi, bobs, bt)
            print('after training losses:')
            print(mpnet.loss(mpnet(bi, bobs), bt))
            loss = mpnet.loss(mpnet(bi, bobs), bt)
            #update_line(hl, ax, [i//args.batch_size, loss.data.numpy()])
            loss_avg += loss.cpu().data
            loss_avg_i += 1
            if loss_avg_i >= loss_steps:
                loss_avg = loss_avg / loss_avg_i
                writer.add_scalar('train_loss', loss_avg, record_i)
                record_i += 1
                loss_avg = 0.
                loss_avg_i = 0

            # validation
            if val_len == 0:
                # nothing was held out (fewer than 20 samples in total)
                continue
            # calculate the corresponding batch in val_dataset
            val_end = val_i + args.batch_size
            bi, bt, bobs = _make_batch(val_dataset[val_i:val_end],
                                       val_targets[val_i:val_end],
                                       val_env_indices[val_i:val_end])
            # Wrap around as soon as the next window would start at or past
            # the end; `>` let an empty batch through whenever val_len was a
            # multiple of the batch size.
            val_i = val_end
            if val_i >= val_len:
                val_i = 0
            loss = mpnet.loss(mpnet(bi, bobs), bt)
            print('validation loss: %f' % (loss.cpu().data))

            val_loss_avg += loss.cpu().data
            val_loss_avg_i += 1
            if val_loss_avg_i >= loss_steps:
                val_loss_avg = val_loss_avg / val_loss_avg_i
                writer.add_scalar('val_loss', val_loss_avg, val_record_i)
                val_record_i += 1
                val_loss_avg = 0.
                val_loss_avg_i = 0
        # Save the models
        if epoch > 0 and epoch % 50 == 0:
            model_path = 'cost_kmpnet_epoch_%d_direction_%d_step_%d.pkl' % (
                epoch, args.direction, args.num_steps)
            #save_state(mpnet, torch_seed, np_seed, py_seed, os.path.join(args.model_path,model_path))
            save_state(mpnet, torch_seed, np_seed, py_seed,
                       os.path.join(model_dir, model_path))
    writer.export_scalars_to_json("./all_scalars.json")
    writer.close()