def mpc_w_learned_dynamics(config,
                           train_dir,
                           mpc_dir,
                           state_dict_path=None,
                           keypoint_observation=False):
    """Run MPC episodes using a learned dynamics model.

    Loads a trained ``DynaNetMLP`` checkpoint, builds the action/observation
    functions and the planner from ``config``, and runs
    ``config['mpc']['num_episodes']`` episodes, each against one row of a
    hard-coded table of goal observations.

    Args:
        config: Nested configuration dict. Keys read here:
            ``train.random_seed``, ``dynamics.model_type``,
            ``mpc.mpc_dy_epoch``, ``mpc.mpc_dy_iter`` and
            ``mpc.num_episodes``.
        train_dir: Directory containing the saved checkpoints; only used
            when ``state_dict_path`` is None.
        mpc_dir: Output directory; ``mpc.log`` is created inside it.
        state_dict_path: Explicit checkpoint path. When None, the path is
            derived from ``train_dir`` and the ``mpc`` config section
            (``mpc_dy_epoch == -1`` selects ``net_best_dy.pth``).
        keypoint_observation: Must be True; other observation types are not
            supported.

    Raises:
        AssertionError: If the model type is not ``'mlp'``, or if an episode
            is about to run with ``keypoint_observation`` False.
        IndexError: If more episodes are requested than there are
            hard-coded goals.
    """
    # set random seed for reproduction
    set_seed(config['train']['random_seed'])

    # Must stay referenced for the duration of the run; presumably mirrors
    # stdout into mpc.log -- confirm against the Tee implementation.
    tee = Tee(os.path.join(mpc_dir, 'mpc.log'), 'w')

    print(config)

    use_gpu = torch.cuda.is_available()

    # ---- model ----
    model_type = config['dynamics']['model_type']
    if model_type != 'mlp':
        raise AssertionError("Unknown model type %s" % model_type)
    model_dy = DynaNetMLP(config)

    # print model #params
    print("model #params: %d" % count_trainable_parameters(model_dy))

    if state_dict_path is None:
        mpc_cfg = config['mpc']
        if mpc_cfg['mpc_dy_epoch'] == -1:
            ckpt_name = 'net_best_dy.pth'
        else:
            ckpt_name = 'net_dy_epoch_%d_iter_%d.pth' % (
                mpc_cfg['mpc_dy_epoch'], mpc_cfg['mpc_dy_iter'])
        state_dict_path = os.path.join(train_dir, ckpt_name)

        print("Loading saved ckp from %s" % state_dict_path)

    # Always deserialize onto the CPU: a checkpoint written from a GPU run
    # would otherwise fail to load on a CPU-only host. The model is moved to
    # the GPU afterwards when one is available.
    model_dy.load_state_dict(torch.load(state_dict_path, map_location='cpu'))
    model_dy.eval()

    if use_gpu:
        model_dy.cuda()

    # generate action/observation functions
    action_function = ActionFunctionFactory.function_from_config(config)
    observation_function = ObservationFunctionFactory.function_from_config(
        config)

    # planner
    planner = planner_from_config(config)

    # ---- env ----
    # Hard-coded goal observations: one row of 10 values per episode.
    # NOTE(review): the values look like 2D keypoint coordinates in pixels,
    # but nothing in this function establishes that -- confirm.
    obs_goals = np.array([
        [262.9843, 267.3102, 318.9369, 351.1229, 360.2048,
         323.5128, 305.6385, 240.4460, 515.4230, 347.8708],
        [381.8694, 273.6327, 299.6685, 331.0925, 328.7724,
         372.0096, 411.0972, 314.7053, 517.7299, 268.4953],
        [284.8728, 275.7985, 374.0677, 320.4990, 395.4019,
         275.4633, 306.2896, 231.4310, 507.0849, 312.4057],
        [313.1638, 271.4258, 405.0255, 312.2325, 424.7874,
         266.3525, 333.6973, 225.7708, 510.1232, 305.3802],
        [308.6859, 270.9629, 394.2789, 323.2781, 419.7905,
         280.1602, 333.8901, 228.1624, 519.1964, 321.5318],
        [386.8067, 284.8947, 294.2467, 323.2223, 313.3221,
         368.9970, 405.9415, 330.9298, 495.9970, 268.9920],
        [432.0219, 299.6021, 340.8581, 339.4676, 360.2354,
         384.5515, 451.4394, 345.2190, 514.6357, 291.2043],
        [351.3389, 264.5325, 267.5279, 318.2321, 293.7460,
         360.0423, 378.4428, 306.9586, 516.4390, 259.7810],
        [521.1902, 254.0693, 492.7884, 349.7861, 539.6320,
         364.5190, 569.2258, 268.8824, 506.9431, 286.9752],
        [264.8554, 275.9547, 338.1317, 345.3435, 372.7012,
         308.4648, 299.3454, 239.9245, 506.2117, 373.8413],
    ])

    num_episodes = config['mpc']['num_episodes']

    # Fail before running anything instead of crashing with an opaque
    # IndexError after the last available goal has been consumed.
    if keypoint_observation and num_episodes > len(obs_goals):
        raise IndexError(
            "config['mpc']['num_episodes'] is %d but only %d goals are defined"
            % (num_episodes, len(obs_goals)))

    for mpc_idx in range(num_episodes):
        if not keypoint_observation:
            # not supported for now
            raise AssertionError("currently only support keypoint observation")

        mpc_episode_keypoint_observation(config,
                                         mpc_idx,
                                         model_dy,
                                         mpc_dir,
                                         planner,
                                         obs_goals[mpc_idx],
                                         action_function,
                                         observation_function,
                                         use_gpu=use_gpu)


# Example #2
def eval_dynamics(config,
                  train_dir,
                  eval_dir,
                  state_dict_path=None,
                  keypoint_observation=False,
                  debug=False,
                  render_human=False):
    """Evaluate a learned dynamics model on validation episodes.

    Loads a trained ``DynaNetMLP`` checkpoint, builds a ``"valid"``-phase
    ``MultiEpisodeDataset`` from the episodes named by ``config``, and runs
    the per-episode evaluation routine on the first episodes in sorted name
    order.

    Args:
        config: Nested configuration dict. Keys read here:
            ``train.random_seed``, ``eval.eval_dy_epoch``,
            ``eval.eval_dy_iter`` and, optionally, ``eval.num_episodes``
            (defaults to 10 when absent).
        train_dir: Directory containing the saved checkpoints; only used
            when ``state_dict_path`` is None.
        eval_dir: Output directory; ``eval.log`` is created inside it and it
            is forwarded to the per-episode evaluation routine.
        state_dict_path: Explicit checkpoint path. When None, the path is
            derived from ``train_dir`` and the ``eval`` config section
            (``eval_dy_epoch == -1`` selects ``net_best_dy.pth``).
        keypoint_observation: Selects ``eval_episode_keypoint_observations``
            instead of ``eval_episode``.
        debug: When True, only the first episode (by sorted name) is
            evaluated.
        render_human: Forwarded unchanged to the per-episode routine.
    """
    # set random seed for reproduction
    set_seed(config['train']['random_seed'])

    # Must stay referenced for the duration of the run; presumably mirrors
    # stdout into eval.log -- confirm against the Tee implementation.
    tee = Tee(os.path.join(eval_dir, 'eval.log'), 'w')

    print(config)

    use_gpu = torch.cuda.is_available()

    # ---- model ----
    model_dy = DynaNetMLP(config)

    # print model #params
    print("model #params: %d" % count_trainable_parameters(model_dy))

    if state_dict_path is None:
        eval_cfg = config['eval']
        if eval_cfg['eval_dy_epoch'] == -1:
            ckpt_name = 'net_best_dy.pth'
        else:
            ckpt_name = 'net_dy_epoch_%d_iter_%d.pth' % (
                eval_cfg['eval_dy_epoch'], eval_cfg['eval_dy_iter'])
        state_dict_path = os.path.join(train_dir, ckpt_name)

        print("Loading saved ckp from %s" % state_dict_path)

    # Always deserialize onto the CPU: a checkpoint written from a GPU run
    # would otherwise fail to load on a CPU-only host. The model is moved to
    # the GPU afterwards when one is available.
    model_dy.load_state_dict(torch.load(state_dict_path, map_location='cpu'))
    model_dy.eval()

    if use_gpu:
        model_dy.cuda()

    # load the data
    episodes = load_episodes_from_config(config)

    # generate action/observation functions
    action_function = ActionFunctionFactory.function_from_config(config)
    observation_function = ObservationFunctionFactory.function_from_config(
        config)

    dataset = MultiEpisodeDataset(config,
                                  action_function=action_function,
                                  observation_function=observation_function,
                                  episodes=episodes,
                                  phase="valid")

    # sorted() builds a new list, so whatever list the dataset handed out is
    # not reordered in place.
    episode_names = sorted(dataset.get_episode_names())

    # "num_episodes" is optional for backwards compatibility
    num_episodes = config["eval"].get("num_episodes", 10)

    if debug:
        episode_list = [episode_names[0]]
    else:
        episode_list = episode_names[:num_episodes]

    # NOTE(review): the rollout window is hard-coded (start_idx=9,
    # n_prediction=30); config['eval']['eval_st_idx'] / ['eval_ed_idx'] were
    # previously read here but never used -- confirm whether they should
    # drive these values.
    evaluate_episode = (eval_episode_keypoint_observations
                        if keypoint_observation else eval_episode)

    for roll_idx, episode_name in enumerate(episode_list):
        print("episode_name", episode_name)
        evaluate_episode(config,
                         dataset,
                         episode_name,
                         roll_idx,
                         model_dy,
                         eval_dir,
                         start_idx=9,
                         n_prediction=30,
                         render_human=render_human)