Example #1
0
def run_validation(args, sigma_obs_arr):
    """Sweep obstacle-cost sigmas and evaluate the GPMP2 planner on a
    validation set, dumping per-sigma metrics to a YAML file.

    Parameters:
        args: argparse-style namespace; fields read here: use_cuda,
            seed_val, dataset_folders, plan_param_file, robot_param_file,
            env_param_file, num_envs, render, step, obs_lambda.
        sigma_obs_arr: iterable of obstacle cost sigma values to sweep.

    Side effects: writes '<dataset_folders[0]>/sensitivity_results.yaml'.
    NOTE(review): Python 2 only (xrange / print statement / raw_input)
    and relies on module-level `pfiletype` plus several project helpers.
    """
    # CUDA is used only if requested AND actually available.
    use_cuda = torch.cuda.is_available() if args.use_cuda else False
    device = torch.device('cuda') if use_cuda else torch.device('cpu')
    # Seed numpy and torch for reproducibility.
    np.random.seed(args.seed_val)
    torch.manual_seed(args.seed_val)

    dataset_folders = [
        os.path.abspath(folder) for folder in args.dataset_folders
    ]
    # Parameter files are read from the first dataset folder only.
    plan_param_file = os.path.join(dataset_folders[0],
                                   args.plan_param_file + pfiletype)
    robot_param_file = os.path.join(dataset_folders[0],
                                    args.robot_param_file + pfiletype)
    env_param_file = os.path.join(dataset_folders[0],
                                  args.env_param_file + pfiletype)

    env_data, planner_params, gp_params, obs_params,\
    optim_params, robot_data = load_params(plan_param_file, robot_param_file, env_param_file, device)

    dataset = PlanningDatasetMulti(dataset_folders,
                                   mode='train',
                                   num_envs=1000,
                                   num_env_probs=1,
                                   label_subdir='opt_trajs_gpmp2')

    # idxs = np.random.choice(len(dataset), args.num_envs, replace=False) if args.num_envs < len(dataset) else xrange(0, len(dataset))
    # The first args.num_envs problems are evaluated (no random subsample).
    idxs = xrange(args.num_envs)
    print idxs
    env_params = {'x_lims': env_data['x_lims'], 'y_lims': env_data['y_lims']}
    # NOTE(review): `robot` is only bound for the 'point_robot' type; any
    # other type would raise a NameError when `robot` is used below.
    if robot_data['type'] == 'point_robot':
        robot = PointRobot2D(robot_data['sphere_radius'],
                             use_cuda=args.use_cuda)
    batch_size = 1  #learn_params['optim']['batch_size']

    #To be used for calculating metrics later
    # dt: duration of one trajectory step (total time / number of steps).
    dt = planner_params['total_time_sec'] * 1.0 / planner_params[
        'total_time_step'] * 1.0
    gpfactor = GPFactor(planner_params['dof'], dt,
                        planner_params['total_time_step'])
    obsfactor = ObstacleFactor(planner_params['state_dim'],
                               planner_params['total_time_step'], 0.0,
                               env_params, robot)
    dof = planner_params['dof']
    use_vel_limits = planner_params[
        'use_vel_limits'] if 'use_vel_limits' in planner_params else False

    results_dict = {}
    # Outer sweep: one full validation pass per obstacle-cost sigma.
    for sigma_obs in sigma_obs_arr:
        print('Curr Sigma = ', sigma_obs)
        obs_params['cost_sigma'] = torch.tensor(sigma_obs, device=device)
        # A fresh planner per sigma so the new cost_sigma takes effect.
        planner = DiffGPMP2Planner(gp_params,
                                   obs_params,
                                   planner_params,
                                   optim_params,
                                   env_params,
                                   robot,
                                   learn_params=None,
                                   batch_size=batch_size,
                                   use_cuda=args.use_cuda)
        planner.to(device)
        planner.eval()

        # criterion = torch_loss(learn_params['optim']['criterion'], reduction=learn_params['optim']['loss_reduction'])

        # Per-sigma metric accumulators, one entry per environment.
        valid_task_loss_per_iter = []
        valid_cost_per_iter = []
        valid_num_iters = []
        valid_gp_error = []
        valid_avg_vel = []
        valid_avg_acc = []
        valid_avg_jerk = []
        valid_in_coll = []
        valid_avg_penetration = []
        valid_max_penetration = []
        valid_coll_intensity = []
        valid_constraint_violation = []

        with torch.no_grad():
            # Evaluate each validation environment/problem.
            for i in idxs:
                sample = dataset[i]
                # print('Environment idx = %d'%i)
                im = sample['im'].to(device)
                sdf = sample['sdf'].to(device)
                start = sample['start'].to(device)
                goal = sample['goal'].to(device)
                th_opt = sample['th_opt'].to(device)
                start_conf = start[0, 0:dof]
                goal_conf = goal[0, 0:dof]
                # Straight-line initialization from start to goal.
                th_init = straight_line_traj(start_conf, goal_conf,
                                             planner_params['total_time_sec'],
                                             planner_params['total_time_step'],
                                             dof, device)
                j = 0

                th_curr = th_init.unsqueeze(0)
                dtheta = torch.zeros_like(th_curr)
                # Zero per-state epsilon trajectory for the obstacle factor.
                eps_traj = torch.zeros(planner_params['total_time_step'] + 1,
                                       robot.nlinks, 1)
                eps_traj = eps_traj.unsqueeze(0).repeat(
                    th_curr.shape[0], 1, 1, 1)
                obsfactor.set_eps(eps_traj.unsqueeze(0))
                curr_hidden = None

                cost_per_iter = []
                task_loss_per_iter = []
                th_best = None
                best_task_loss = np.inf
                if args.render:
                    th_init_np = th_init.cpu().detach().numpy()
                    th_opt_np = th_opt.cpu().detach().numpy()
                    env = Env2D(env_params)
                    env.initialize_from_image(im[0], sdf[0])
                    # NOTE(review): in Python 2 these comprehensions rebind
                    # the outer loop variable `i`; harmless here since `i`
                    # is only read at the top of each iteration.
                    path_init = [
                        th_init_np[i, 0:dof]
                        for i in xrange(planner_params['total_time_step'] + 1)
                    ]
                    path_opt = [
                        th_opt_np[i, 0:dof]
                        for i in xrange(planner_params['total_time_step'] + 1)
                    ]
                    env.initialize_plot(start_conf.cpu().numpy(),
                                        goal_conf.cpu().numpy())
                    env.plot_signed_distance_transform()
                    raw_input('Enter to start ...')
                    plt.show(block=False)

                # Inner optimization loop: step the planner until converged.
                while True:
                    # print("Current iteration = %d"%j)
                    if args.render:
                        th_curr_np = th_curr.cpu().detach().numpy()
                        path_curr = [
                            th_curr_np[0, i, 0:dof]
                            for i in xrange(planner_params['total_time_step'] +
                                            1)
                        ]
                        if j > 0: env.clear_edges()
                        env.plot_edge(
                            path_curr, color='blue'
                        )  #, linestyle='-', linewidth=0.01*j , alpha=1.0-(1.0/(j+0.0001)) )
                        plt.show(block=False)
                        time.sleep(0.002)
                        if args.step:
                            raw_input('Press enter for next step')

                    # One planner update: trajectory delta plus errors.
                    dtheta, curr_hidden, err_old, err_ext_old, qc_inv_traj, obscov_inv_traj, eps_traj = planner.step(
                        th_curr, start.unsqueeze(0), goal.unsqueeze(0),
                        im.unsqueeze(0), sdf.unsqueeze(0), dtheta, curr_hidden)
                    err_sg, err_gp, err_obs = planner.unweighted_errors_batch(
                        th_curr, sdf.unsqueeze(0))
                    task_loss = err_gp + args.obs_lambda * err_obs

                    #We only keep the best trajectory so far
                    if task_loss.item() < best_task_loss:
                        th_best = th_curr
                        best_task_loss = task_loss.item()

                    task_loss_per_iter.append(task_loss.item())
                    cost_per_iter.append(err_old.item())

                    # Apply the update and recompute errors for convergence.
                    th_old = th_curr
                    th_curr = th_curr + dtheta
                    th_new = th_curr
                    err_new = planner.error_batch(th_curr,
                                                  sdf.unsqueeze(0)).item()
                    err_ext_new = planner.error_ext_batch(
                        th_curr, sdf.unsqueeze(0)).item()

                    err_delta = err_new - err_old[0]
                    err_ext_delta = err_ext_new - err_ext_old[0]
                    # print('|dtheta| = %f, err = %f, err_ext = %f, err_delta = %f,\
                    #        |qc_inv| = %f, |obscov_inv| = %f'%(torch.norm(dtheta), err_old[0], err_delta, err_ext_delta,\
                    #                                           torch.norm(qc_inv_traj, p='fro', dim=(2,3)).mean(),\
                    #                                           torch.norm(obscov_inv_traj, p='fro', dim=(2,3)).mean()))

                    j = j + 1
                    # Stop on small update/error delta or max iterations.
                    if check_convergence(dtheta, j, torch.tensor(err_delta),
                                         optim_params['tol_err'],
                                         optim_params['tol_delta'],
                                         optim_params['max_iters']):
                        # print('Converged')
                        break
                # Metrics are computed on the best trajectory seen, not the
                # last iterate.
                th_final = th_best
                #########################METRICS##########################################
                # print th_final
                avg_vel, avg_acc, avg_jerk = smoothness_metrics(
                    th_final[0], planner_params['total_time_sec'],
                    planner_params['total_time_step'])
                gp_error, _, _ = gpfactor.get_error(th_final)
                obs_error, _ = obsfactor.get_error(th_final, sdf.unsqueeze(0))
                mse_gp = torch.mean(torch.sum(gp_error**2, dim=-1))
                in_coll, avg_penetration, max_penetration, coll_int = collision_metrics(
                    th_final[0], obs_error[0],
                    planner_params['total_time_sec'],
                    planner_params['total_time_step'])
                print('Trajectory in collision = ', in_coll)
                # print('MSE GP = {}, Average velocity = {}, average acc = {}, avg jerk=  {}'.format(mse_gp, avg_vel, avg_acc, avg_jerk))
                # print('In coll = {}, average penetration = {}, max penetration = {}, collision intensity =  {}'.format(in_coll,
                # avg_penetration,
                # max_penetration,
                # coll_int))
                constraint_violation = 0.0
                # Fraction of states whose velocity exceeds the limits.
                # NOTE(review): assumes velocity components live at state
                # indices 2 and 3 — confirm against the state layout.
                if use_vel_limits:  #planner_params['use_vel_limits']:
                    v_x_lim = gp_params['v_x']
                    v_y_lim = gp_params['v_y']
                    for i in xrange(th_final.shape[1]):
                        s = th_final[0][i]
                        v_x = s[2]
                        v_y = s[3]
                        if torch.abs(v_x) <= v_x_lim and torch.abs(
                                v_y) <= v_y_lim:
                            continue
                        else:
                            constraint_violation = constraint_violation + 1.0
                            print('Constraint violatrion!!!!!')
                constraint_violation = constraint_violation / (
                    th_final.shape[1] * 1.0)

                valid_gp_error.append(mse_gp.item())
                valid_avg_vel.append(avg_vel.item())
                valid_avg_acc.append(avg_acc.item())
                valid_avg_jerk.append(avg_jerk.item())
                valid_in_coll.append(in_coll)
                valid_avg_penetration.append(avg_penetration.item())
                valid_max_penetration.append(max_penetration.item())
                valid_coll_intensity.append(coll_int)
                valid_constraint_violation.append(constraint_violation)

                # Final task loss/cost of the best trajectory are appended
                # as the last per-iteration entries.
                err_sg, err_gp, err_obs = planner.unweighted_errors_batch(
                    th_final, sdf.unsqueeze(0))
                task_loss = err_sg + err_gp + args.obs_lambda * err_obs

                task_loss_per_iter.append(task_loss.item())
                cost_per_iter.append(
                    planner.error_batch(th_final, sdf.unsqueeze(0)).item())

                valid_task_loss_per_iter.append(task_loss_per_iter)
                valid_cost_per_iter.append(cost_per_iter)
                valid_num_iters.append(j)

                if args.render:
                    th_final_np = th_final.cpu().detach().numpy()
                    path_final = [
                        th_final_np[0][i, 0:dof]
                        for i in xrange(planner_params['total_time_step'] + 1)
                    ]
                    env.clear_edges()
                    env.plot_edge(path_final)  #, linewidth=0.1*j)
                    plt.show(block=False)
                    raw_input('Press enter for next env')
                    env.close_plot()

        # Collect this sigma's metrics, keyed by the sigma value.
        results_dict_sig = {}
        results_dict_sig['num_iters'] = valid_num_iters
        # results_dict_sig['cost_per_iter']        = valid_cost_per_iter
        results_dict_sig['gp_mse'] = valid_gp_error
        results_dict_sig['avg_vel'] = valid_avg_vel
        results_dict_sig['avg_acc'] = valid_avg_acc
        results_dict_sig['avg_jerk'] = valid_avg_jerk
        results_dict_sig['in_collision'] = valid_in_coll
        results_dict_sig['avg_penetration'] = valid_avg_penetration
        results_dict_sig['max_penetration'] = valid_max_penetration
        results_dict_sig['coll_intensity'] = valid_coll_intensity
        # results_dict_sig['task_loss_per_iter']   = valid_task_loss_per_iter
        results_dict_sig['constraint_violation'] = valid_constraint_violation
        results_dict[str(sigma_obs)] = results_dict_sig

        print('Avg unsolved = ', np.mean(valid_in_coll))

    # Persist all per-sigma results in the first dataset folder.
    print('Dumping results')
    filename = 'sensitivity_results.yaml'
    # else: filename = args.model_file+"_valid_results.yaml"
    with open(dataset_folders[0] + '/' + filename, 'w') as fp:
        yaml.dump(results_dict, fp)
Example #2
0
# Padding sizes in grid cells, derived from metric distances.
patch_size_safety = int(np.ceil(
    safety_distance.item() / cell_size *
    1.0))  #Padding size for sdf and min distance of start goal from obstacles
patch_size_robot = int(np.ceil(robot_radius.item() / cell_size * 1.0))

#Generate and save training data
dataset_number = datasets[args.dataset_type]
map_dim = (args.im_size, args.im_size)

env_params = {
    'x_lims': env_data['x_lims'],
    'y_lims': env_data['y_lims'],
    'padlen': patch_size_safety
}
env = Env2D(env_params)
robot = PointRobot2D(robot_data['sphere_radius'])

env_number = 0
# Generate environments until args.num_train have been created.
while env_number < args.num_train:
    print('Creating env number %d' % env_number)
    #Sample start and goal in meters
    far_enough = False
    # Rejection-sample start/goal pairs until they are valid.
    # NOTE(review): chunk is truncated below — the acceptance test that
    # sets far_enough is outside this view.
    while not far_enough:
        print("Sampling valid start and goal")
        # Uniform samples in [sgmin, sgmax] for each of probs_per_env
        # planning problems (one row per problem).
        start_xs = (sgmax_x - sgmin_x) * torch.rand(args.probs_per_env,
                                                    1) + sgmin_x
        start_ys = (sgmax_y - sgmin_y) * torch.rand(args.probs_per_env,
                                                    1) + sgmin_y
        goal_xs = (sgmax_x - sgmin_x) * torch.rand(args.probs_per_env,
                                                   1) + sgmin_x
        goal_ys = (sgmax_y - sgmin_y) * torch.rand(args.probs_per_env,
Example #3
0
    # Fragment: tail of a loop whose header is above this chunk —
    # keeps the batch sampled at i == 1 and stops iterating.
    sample_batch = sample
    if i == 1:
        break
env_params = {'x_lims': env_data['x_lims'], 'y_lims': env_data['y_lims']}
imb = sample_batch['im']
# Resolution inferred from the image width and workspace x-limits
# (presumably meters per cell — confirm).
res = (env_params['x_lims'][1] - env_params['x_lims'][0]) / (imb[0].shape[-1] *
                                                             1.)
# Scale the sdf by the resolution (cell units -> workspace units).
sdfb = sample_batch['sdf'] * res
startb = sample_batch['start']
goalb = sample_batch['goal']
th_optb = sample_batch['th_opt']
# env_params_b = sample_batch['env_params']
#2D Point robot model
total_time_step = planner_params['total_time_step']
robot = PointRobot2D(robot_data['sphere_radius'][0],
                     batch_size,
                     total_time_step + 1,
                     use_cuda=use_cuda)

#Initial trajectories are just straight lines from start to goal
total_time_sec = planner_params['total_time_sec']
dof = planner_params['dof']
th_init_tmp = torch.zeros(
    (batch_size, int(total_time_step) + 1, planner_params['state_dim']),
    device=device)  #Straight   line at constant velocity

# Linear interpolation from start to goal for every batch element.
# NOTE(review): the assignment statement is truncated at the end of
# this chunk.
for j in xrange(batch_size):
    avg_vel = (goalb[j][0, 0:dof] - startb[j][0, 0:dof]) / total_time_sec
    for i in range(int(total_time_step) + 1):
        th_init_tmp[j][i, 0:2] = startb[j][0, 0:dof] * (
            total_time_step - i) * 1. / total_time_step * 1. + goalb[j][
                0,
Example #4
0
# Select compute device from the module-level use_cuda flag
# (defined above this chunk).
device = torch.device('cuda') if use_cuda else torch.device('cpu')


plan_param_file = os.path.abspath('gpmp2_2d_params.yaml')
robot_param_file = os.path.abspath('robot_2d.yaml')
env_param_file = os.path.abspath('env_2d_params.yaml')
ENV_FILE = os.path.abspath("../diff_gpmp2/env/simple_2d/1.png")

#Load the environment and planning parameters
env_data, planner_params, gp_params, obs_params, optim_params, robot_data = load_params(plan_param_file, robot_param_file, env_param_file, device)
env_params = {'x_lims': env_data['x_lims'], 'y_lims': env_data['y_lims']}
env = Env2D(env_params, use_cuda=use_cuda)
env.initialize_from_file(ENV_FILE)

#2D Point robot model
robot = PointRobot2D(robot_data['sphere_radius'][0], use_cuda=use_cuda)

# Fixed start and goal states: 2D position plus zero velocity.
start_conf = torch.tensor([[-4., -4.]], device=device)
start_vel = torch.tensor([[0., 0.]], device=device)
goal_conf = torch.tensor([[4., 4.]], device=device)#[17, 14])
goal_vel = torch.tensor([[0., 0.]], device=device)
# Constant velocity needed to reach the goal in total_time_sec.
avg_vel = (goal_conf - start_conf)/planner_params['total_time_sec']
start = torch.cat((start_conf, start_vel), dim=1)
goal = torch.cat((goal_conf, goal_vel), dim=1)
total_time_step = planner_params['total_time_step']
th_init_tmp = torch.zeros((int(total_time_step)+1, planner_params['state_dim']), device=device) #Straight   line at constant velocity

# Linearly interpolate the positions and fill in the constant velocity.
for i in range(int(total_time_step)+1):
  th_init_tmp[i, 0:2] = start_conf*(total_time_step - i)*1./total_time_step*1. + goal_conf * i*1./total_time_step*1. #+ np.array([0., 5.0])
  th_init_tmp[i, 2:4] = avg_vel
Example #5
0
from diff_gpmp2.robot_models import PointRobot2D
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

# Print full numpy arrays without summarization. threshold=np.inf
# replaces the old threshold=np.nan, which modern NumPy rejects with
# "ValueError: threshold must be non-NAN"; inf keeps the original
# "always print everything" behavior (size > inf is never true).
np.set_printoptions(threshold=np.inf, linewidth=np.inf)
pp = pprint.PrettyPrinter()
#Load the environment
ENV_FILE = os.path.abspath("../diff_gpmp2/env/test_env.png")
env = Env2D()
env_params = dict()
env_params['y_lims'] = [-20, 20]
env_params['x_lims'] = [-20, 20]
env.initialize(ENV_FILE, env_params)
# Precompute and display the signed distance field of the environment.
env.calculate_signed_distance_transform()
env.plot_signed_distance_transform()

#2D Point robot model
sphere_radius = 2
robot = PointRobot2D(sphere_radius)

# Problem dimensions: 2 dofs, 4D state.
dof = 2
state_dim = 4
eps = 2  # NOTE(review): presumably the obstacle safety margin — confirm
cov = np.eye(robot.nlinks)

# Evaluate the obstacle factor error (and H) at a single test state.
obs_factor = ObstacleFactor(state_dim, cov, eps, env, robot)
err, H = obs_factor.get_error(np.array([13, 5, 0, 0]))


plt.show()
Example #6
0
def test(model, criterion, gp_prior, valid_loader, valid_idxs, env_data,
         planner_params, gp_params, obs_params, optim_params, robot_data,
         learn_params, model_folder, results_folder, use_cuda, render):
    """Evaluate a one-step trajectory-prediction model on validation envs.

    For each index in valid_idxs: build a straight-line initialization,
    run `model` once on the (image, sdf) stack to predict a trajectory
    update, and accumulate the one-step loss and the solved count.
    Optionally renders each predicted vs. expert path.

    Returns:
        (avg_loss, avg_solved, avg_gpmse) averaged over the validation
        environments. NOTE(review): avg_gpmse is always 0.0 because the
        smoothness_error accumulation is commented out.

    Fix: this function previously read the module-level `args.use_cuda`
    and `args.render` instead of its own `use_cuda` and `render`
    parameters; it now honors the parameters it is given.
    """
    device = torch.device('cuda') if use_cuda else torch.device('cpu')
    # Work in double precision on the selected device.
    if use_cuda: torch.set_default_tensor_type(torch.cuda.DoubleTensor)
    else: torch.set_default_tensor_type(torch.DoubleTensor)
    #Create planner object
    env_params = {'x_lims': env_data['x_lims'], 'y_lims': env_data['y_lims']}
    # NOTE(review): `robot` is only bound for 'point_robot_2d'; other
    # types would raise NameError at check_solved below.
    if robot_data['type'] == 'point_robot_2d':
        # Fixed: was args.use_cuda (module global), ignoring the parameter.
        robot = PointRobot2D(robot_data['sphere_radius'],
                             use_cuda=use_cuda)
    x_min = env_data['x_lims'][0]
    x_max = env_data['x_lims'][1]
    # Workspace extent divided by image size (presumably meters per
    # pixel — confirm against the dataset).
    cell_size = (x_max - x_min) / learn_params['im_size'] * 1.0
    dof = planner_params['dof']
    avg_loss = 0.0
    avg_solved = 0.0
    avg_gpmse = 0.0
    num_envs = len(valid_idxs)
    model.eval()
    with torch.no_grad():
        for i in valid_idxs:
            # print('Environment idx = %d'%i)
            sample = valid_loader.dataset[i]
            im = sample['im'].unsqueeze(0).to(device)
            sdf = sample['sdf'].unsqueeze(0).to(device)
            start = sample['start'].unsqueeze(0).to(device)
            goal = sample['goal'].unsqueeze(0).to(device)
            target = sample['th_opt'].unsqueeze(0).to(device)
            start_conf = start[:, :, 0:dof]
            goal_conf = goal[:, :, 0:dof]
            th_init = straight_line_trajb(start_conf, goal_conf,
                                          planner_params['total_time_sec'],
                                          planner_params['total_time_step'],
                                          dof, device)
            # Model input: image and sdf stacked along the channel dim.
            data = torch.cat((im, sdf), dim=1)
            # Positions only (x, y columns) of the initialization.
            th_initx = torch.index_select(th_init, -1, torch.tensor(0))
            th_inity = torch.index_select(th_init, -1, torch.tensor(1))
            th_initpos = torch.cat((th_initx, th_inity), dim=-1)
            output = model(data,
                           th_initpos)  #start[:,:,0:dof], goal[:,:,0:dof])
            avg_loss = avg_loss + one_step_loss(output, target - th_init,
                                                criterion, learn_params,
                                                gp_prior).item()
            # Predicted trajectory = initialization + predicted delta.
            th_final = th_initpos + output
            avg_solved = avg_solved + check_solved(
                th_final, sdf, robot_data['sphere_radius'], cell_size,
                env_params, use_cuda)
            # avg_gpmse = avg_gpmse + smoothness_error(output, gp_prior)
            # Fixed: was args.render — the `render` parameter was ignored.
            if render:
                env = Env2D(env_params)
                env.initialize_from_image(im[0, 0].cpu(), sdf[0, 0].cpu())
                env.initialize_plot(start[0, 0, 0:2].cpu().numpy(),
                                    goal[0, 0, 0:2].cpu().numpy())
                env.plot_signed_distance_transform()
                th_final_np = th_final[0].cpu().detach().numpy()
                th_opt_np = target[0].cpu().detach().numpy()
                path_final = [
                    th_final_np[i, 0:planner_params['dof']]
                    for i in xrange(planner_params['total_time_step'] + 1)
                ]
                path_opt = [
                    th_opt_np[i, 0:planner_params['dof']]
                    for i in xrange(planner_params['total_time_step'] + 1)
                ]
                env.plot_edge(
                    path_final, color='blue'
                )  #, linestyle='-', linewidth=0.01*j , alpha=1.0-(1.0/(j+0.0001)) )
                env.plot_edge(
                    path_opt, color='red', linestyle='--'
                )  #, linewidth=0.01*j , alpha=1.0-(1.0/(j+0.0001)) )
                plt.show(block=False)
                # time.sleep(0.002)
                raw_input('Press enter for next step')
                env.close_plot()

    return avg_loss / num_envs * 1.0, avg_solved / num_envs * 1.0, avg_gpmse / num_envs * 1.0
Example #7
0
def generate_trajs_and_save(folder,
                            num_envs,
                            probs_per_env,
                            env_data,
                            planner_params,
                            gp_params,
                            obs_params,
                            optim_params,
                            robot_data,
                            out_folder_name,
                            rrt_star_init=False,
                            fix_start_goal=False):
    """Run the planner on every (environment, problem) pair and save
    the start, goal and final trajectory as .npz files, plus meta.yaml.

    Parameters:
        folder: dataset root; must contain im_sdf/<i>_im.png and
            im_sdf/<i>_sdf.npy for each environment index i.
        num_envs, probs_per_env: number of environments and problems per
            environment to generate.
        env_data..robot_data: parameter dicts driving the planner/robot.
        out_folder_name: subfolder of `folder` for the .npz outputs.
        rrt_star_init, fix_start_goal: forwarded to generate_start_goal.

    NOTE(review): relies on module-level `device`, `use_cuda` and
    `args.im_size`. Only the env_data['dim'] == 2 branch binds env,
    robot, imp and sdfp — any other dim would raise NameError below.
    """
    for i in xrange(num_envs):
        if env_data['dim'] == 2:
            env_params = {
                'x_lims': env_data['x_lims'],
                'y_lims': env_data['y_lims']
            }
            env = Env2D(env_params)
            if robot_data['type'] == 'point_robot':
                robot = PointRobot2D(robot_data['sphere_radius'])
            # print   robot.get_sphere_radii()
            # Load the precomputed image and signed distance field.
            im = plt.imread(folder + "/im_sdf/" + str(i) + "_im.png")
            sdf = np.load(folder + "/im_sdf/" + str(i) + "_sdf.npy")
            env.initialize_from_image(im, sdf)
            imp = torch.tensor(im, device=device)
            sdfp = torch.tensor(sdf, device=device)

        for j in xrange(probs_per_env):
            # A new planner is constructed for every problem.
            planner = DiffGPMP2Planner(gp_params,
                                       obs_params,
                                       planner_params,
                                       optim_params,
                                       env_params,
                                       robot,
                                       use_cuda=use_cuda)
            start, goal, th_init = generate_start_goal(
                env_params, planner_params, env_data['dim'], j, env, robot,
                obs_params, rrt_star_init, fix_start_goal)
            th_final,_, err_init, err_final, err_per_iter, err_ext_per_iter, k, time_taken = \
                                                                            planner.forward(th_init.unsqueeze(0), start.unsqueeze(0), goal.unsqueeze(0), imp.unsqueeze(0).unsqueeze(0), sdfp.unsqueeze(0).unsqueeze(0))
            print('Num iterations = %d, Time taken %f' % (k[0], time_taken[0]))

            path_init = []
            path_final = []

            start_np = start.cpu().detach().numpy()[0]
            goal_np = goal.cpu().detach().numpy()[0]
            th_init_np = th_init.cpu().detach().numpy()
            th_final_np = th_final[0].cpu().detach().numpy()
            out_folder = os.path.join(folder, out_folder_name)
            if not os.path.exists(out_folder):
                os.makedirs(out_folder)
            # One output file per (environment, problem) pair.
            out_path = out_folder + "/" + "env_" + str(i) + "_prob_" + str(j)
            np.savez(out_path,
                     start=start_np,
                     goal=goal_np,
                     th_opt=th_final_np)

    print('Saving meta data')
    with open(os.path.join(folder, "meta.yaml"), 'w') as fp:
        d = {
            'num_envs': num_envs,
            'probs_per_env': probs_per_env,
            'env_params': env_params,
            'im_size': args.im_size
        }
        yaml.dump(d, fp)
Example #8
0
    # Fragment: interior of a training-setup function whose header is
    # above this chunk (input_folder, args, device defined there).
    learn_param_file = os.path.join(input_folder,
                                    args.learn_param_file + pfiletype)

    model_folder = os.path.join(input_folder,
                                "models")  #folder to store learnt models in
    results_folder = os.path.join(input_folder, "results")
    # Create the output folders on first run.
    if not os.path.exists(model_folder): os.makedirs(model_folder)
    if not os.path.exists(results_folder): os.makedirs(results_folder)
    env_data, planner_params, gp_params, obs_params,\
    optim_params, robot_data, learn_params = load_params_learn(plan_param_file, robot_param_file, env_param_file,
                                                               learn_param_file, device)
    print(optim_params, robot_data, learn_params, env_data, planner_params,
          gp_params, obs_params)
    env_params = {'x_lims': env_data['x_lims'], 'y_lims': env_data['y_lims']}
    # NOTE(review): `robot` is only bound for 'point_robot_2d'.
    if robot_data['type'] == 'point_robot_2d':
        robot = PointRobot2D(robot_data['sphere_radius'],
                             use_cuda=args.use_cuda)

    # Build train/validation loaders according to learn_params['data'].
    train_loader, valid_loader, train_idxs, valid_idxs = train_valid_split(
        dataset_folders,
        learn_params['data']['valid_size'],
        learn_params['data']['expert'],
        learn_params['optim']['batch_size'],
        learn_params['data']['shuffle'],
        learn_params['data']['num_workers'],
        learn_params['data']['num_train_envs'],
        learn_params['data']['num_train_env_probs'],
        learn_params['data']['pin_memory'],
    )

    # Image size comes from the dataset's own metadata.
    learn_params['im_size'] = train_loader.dataset.meta_data[0]['im_size']
    num_traj_states = planner_params['total_time_step'] + 1