def mpc_w_learned_dynamics(config,
                           train_dir,
                           mpc_dir,
                           state_dict_path=None,
                           keypoint_observation=False):
    """Run MPC episodes using a previously trained dynamics model.

    Seeds the RNGs, opens ``mpc.log`` inside ``mpc_dir``, restores the
    dynamics network from a checkpoint, builds the action/observation
    functions and the planner from ``config``, and then runs
    ``config['mpc']['num_episodes']`` episodes, each with its own
    hard-coded goal observation.

    Args:
        config: Nested configuration mapping. This function reads
            ``config['train']['random_seed']``,
            ``config['dynamics']['model_type']`` and, under
            ``config['mpc']``, ``num_episodes``, ``mpc_dy_epoch`` and
            ``mpc_dy_iter`` (the latter two only when
            ``state_dict_path`` is None).
        train_dir: Directory searched for the checkpoint when
            ``state_dict_path`` is None.
        mpc_dir: Directory that receives ``mpc.log``; also forwarded to
            the per-episode routine.
        state_dict_path: Explicit checkpoint path. When None,
            ``net_best_dy.pth`` is used if ``mpc_dy_epoch`` is -1,
            otherwise ``net_dy_epoch_<epoch>_iter_<iter>.pth``.
        keypoint_observation: Must be True; the non-keypoint path is not
            implemented.

    Raises:
        AssertionError: If the model type is not ``'mlp'``, or if
            ``keypoint_observation`` is False when an episode starts.
        IndexError: If ``num_episodes`` exceeds the number of hard-coded
            goals (10).
    """
    # set random seed for reproduction
    set_seed(config['train']['random_seed'])

    # Keep a reference for the duration of the run.
    # NOTE(review): presumably Tee mirrors stdout into the log file - confirm.
    tee = Tee(os.path.join(mpc_dir, 'mpc.log'), 'w')

    print(config)

    use_gpu = torch.cuda.is_available()

    # ---- model ----
    if config['dynamics']['model_type'] == 'mlp':
        model_dy = DynaNetMLP(config)
    else:
        raise AssertionError("Unknown model type %s" %
                             config['dynamics']['model_type'])

    # print model #params
    print("model #params: %d" % count_trainable_parameters(model_dy))

    if state_dict_path is None:
        if config['mpc']['mpc_dy_epoch'] == -1:
            state_dict_path = os.path.join(train_dir, 'net_best_dy.pth')
        else:
            state_dict_path = os.path.join(
                train_dir, 'net_dy_epoch_%d_iter_%d.pth' %
                (config['mpc']['mpc_dy_epoch'], config['mpc']['mpc_dy_iter']))

    print("Loading saved ckp from %s" % state_dict_path)

    # Load onto the CPU first so that a checkpoint written from a CUDA model
    # can still be restored on a machine without a GPU; the model is moved to
    # the GPU below when one is available.
    model_dy.load_state_dict(torch.load(state_dict_path, map_location='cpu'))
    model_dy.eval()

    if use_gpu:
        model_dy.cuda()

    # generate action/observation functions
    action_function = ActionFunctionFactory.function_from_config(config)
    observation_function = ObservationFunctionFactory.function_from_config(
        config)

    # planner
    planner = planner_from_config(config)

    # ---- env ----
    # One goal observation (10 values) per episode; only 10 goals are defined,
    # so indexing fails for mpc_idx >= 10.
    obs_goals = np.array([
        [262.9843, 267.3102, 318.9369, 351.1229, 360.2048,
         323.5128, 305.6385, 240.4460, 515.4230, 347.8708],
        [381.8694, 273.6327, 299.6685, 331.0925, 328.7724,
         372.0096, 411.0972, 314.7053, 517.7299, 268.4953],
        [284.8728, 275.7985, 374.0677, 320.4990, 395.4019,
         275.4633, 306.2896, 231.4310, 507.0849, 312.4057],
        [313.1638, 271.4258, 405.0255, 312.2325, 424.7874,
         266.3525, 333.6973, 225.7708, 510.1232, 305.3802],
        [308.6859, 270.9629, 394.2789, 323.2781, 419.7905,
         280.1602, 333.8901, 228.1624, 519.1964, 321.5318],
        [386.8067, 284.8947, 294.2467, 323.2223, 313.3221,
         368.9970, 405.9415, 330.9298, 495.9970, 268.9920],
        [432.0219, 299.6021, 340.8581, 339.4676, 360.2354,
         384.5515, 451.4394, 345.2190, 514.6357, 291.2043],
        [351.3389, 264.5325, 267.5279, 318.2321, 293.7460,
         360.0423, 378.4428, 306.9586, 516.4390, 259.7810],
        [521.1902, 254.0693, 492.7884, 349.7861, 539.6320,
         364.5190, 569.2258, 268.8824, 506.9431, 286.9752],
        [264.8554, 275.9547, 338.1317, 345.3435, 372.7012,
         308.4648, 299.3454, 239.9245, 506.2117, 373.8413],
    ])

    for mpc_idx in range(config['mpc']['num_episodes']):
        if keypoint_observation:
            mpc_episode_keypoint_observation(config,
                                             mpc_idx,
                                             model_dy,
                                             mpc_dir,
                                             planner,
                                             obs_goals[mpc_idx],
                                             action_function,
                                             observation_function,
                                             use_gpu=use_gpu)
        else:
            # not supported for now
            raise AssertionError("currently only support keypoint observation")
def eval_dynamics(config,
                  train_dir,
                  eval_dir,
                  state_dict_path=None,
                  keypoint_observation=False,
                  debug=False,
                  render_human=False):
    """Evaluate a trained dynamics model on validation episodes.

    Seeds the RNGs, opens ``eval.log`` inside ``eval_dir``, restores a
    ``DynaNetMLP`` from a checkpoint, builds the ``"valid"`` phase
    ``MultiEpisodeDataset`` and runs the per-episode evaluation routine on
    the first episodes in sorted name order.

    Args:
        config: Nested configuration mapping. This function reads
            ``config['train']['random_seed']`` and, under
            ``config['eval']``, ``eval_dy_epoch`` and ``eval_dy_iter``
            (only when ``state_dict_path`` is None) plus the optional
            ``num_episodes`` (defaults to 10).
        train_dir: Directory searched for the checkpoint when
            ``state_dict_path`` is None.
        eval_dir: Directory that receives ``eval.log``; also forwarded to
            the per-episode routine.
        state_dict_path: Explicit checkpoint path. When None,
            ``net_best_dy.pth`` is used if ``eval_dy_epoch`` is -1,
            otherwise ``net_dy_epoch_<epoch>_iter_<iter>.pth``.
        keypoint_observation: Selects
            ``eval_episode_keypoint_observations`` instead of
            ``eval_episode``.
        debug: When True, evaluate only the first episode.
        render_human: Forwarded unchanged to the per-episode routine.
    """
    # set random seed for reproduction
    set_seed(config['train']['random_seed'])

    # Keep a reference for the duration of the run.
    # NOTE(review): presumably Tee mirrors stdout into the log file - confirm.
    tee = Tee(os.path.join(eval_dir, 'eval.log'), 'w')

    print(config)

    use_gpu = torch.cuda.is_available()

    # ---- model ----
    model_dy = DynaNetMLP(config)

    # print model #params
    print("model #params: %d" % count_trainable_parameters(model_dy))

    if state_dict_path is None:
        if config['eval']['eval_dy_epoch'] == -1:
            state_dict_path = os.path.join(train_dir, 'net_best_dy.pth')
        else:
            state_dict_path = os.path.join(
                train_dir, 'net_dy_epoch_%d_iter_%d.pth' %
                (config['eval']['eval_dy_epoch'],
                 config['eval']['eval_dy_iter']))

    print("Loading saved ckp from %s" % state_dict_path)

    # Load onto the CPU first so that a checkpoint written from a CUDA model
    # can still be restored on a machine without a GPU; the model is moved to
    # the GPU below when one is available.
    model_dy.load_state_dict(torch.load(state_dict_path, map_location='cpu'))
    model_dy.eval()

    if use_gpu:
        model_dy.cuda()

    # load the data
    episodes = load_episodes_from_config(config)

    # generate action/observation functions
    action_function = ActionFunctionFactory.function_from_config(config)
    observation_function = ObservationFunctionFactory.function_from_config(
        config)

    dataset = MultiEpisodeDataset(config,
                                  action_function=action_function,
                                  observation_function=observation_function,
                                  episodes=episodes,
                                  phase="valid")

    episode_names = dataset.get_episode_names()
    episode_names.sort()

    # "num_episodes" is optional for backwards compatibility with older
    # configs.
    num_episodes = config["eval"].get("num_episodes", 10)

    # Slicing (rather than indexing [0]) keeps the debug path a no-op on an
    # empty dataset, matching the non-debug path.
    if debug:
        episode_list = episode_names[:1]
    else:
        episode_list = episode_names[:num_episodes]

    # Both routines are called with the same arguments, so pick one up front.
    if keypoint_observation:
        eval_fn = eval_episode_keypoint_observations
    else:
        eval_fn = eval_episode

    for roll_idx, episode_name in enumerate(episode_list):
        print("episode_name", episode_name)
        eval_fn(config,
                dataset,
                episode_name,
                roll_idx,
                model_dy,
                eval_dir,
                start_idx=9,
                n_prediction=30,
                render_human=render_human)