Exemplo n.º 1
0
def configure_her(params):
    """Build the HER transition sampler described by ``params``.

    The environment is obtained through ``cached_make_env`` from
    ``params['make_env']`` and reset once. A vectorized reward function that
    closes over that environment is then handed to
    ``make_sample_her_transitions`` together with the HER options found in
    ``params``.

    Side effect: each recognised HER option present in ``params`` is moved
    from ``params[name]`` to ``params['_' + name]``.

    :param params: (dict) configuration; must contain ``'make_env'``. When
        ``'nearby_action_penalty'`` is present, ``'nearby_p'``,
        ``'perturb_scale'``, ``'cells_apart'`` and ``'perturb_to_feasible'``
        are required as well.
    :return: the object returned by ``make_sample_her_transitions``
    """
    env = cached_make_env(params['make_env'])
    env.reset()

    def reward_fun(ag_2, g, o, **kwargs):  # vectorized
        """Reward for a batch of transitions.

        :param ag_2: achieved goals
        :param g: desired goals
        :param o: observations; only read when ``env.relative_goal`` is set
        :param kwargs: accepted and ignored
        :return: ``success * env.goal_weight
            - env.extend_dist_rew_weight * goal_distance``
        :raises NotImplementedError: if ``env.distance_metric`` is neither
            ``'L1'``, ``'L2'`` nor a callable
        """
        # With relative goals the last two observation columns are used
        # directly as the offset; otherwise it is achieved minus desired goal.
        offset = o[:, -2:] if env.relative_goal else ag_2 - g

        metric = env.distance_metric
        if metric == 'L1' or metric == 'L2':
            norm_order = 1 if metric == 'L1' else 2
            goal_distance = np.linalg.norm(offset, ord=norm_order, axis=-1)
        elif callable(metric):
            # A user-supplied metric receives the raw goals, not the offset.
            goal_distance = metric(ag_2, g)
        else:
            raise NotImplementedError('Unsupported distance metric type.')

        if env.only_feasible:
            # Success additionally requires the desired goal to be feasible.
            within_eps = goal_distance < env.terminal_eps
            feasible = [env.is_feasible(goal) for goal in g]
            success = np.logical_and(within_eps, feasible)
        else:
            success = goal_distance < env.terminal_eps

        return success * env.goal_weight - env.extend_dist_rew_weight * goal_distance

    # Prepare configuration for HER.
    her_params = {'reward_fun': reward_fun}
    optional_names = (
        'replay_strategy', 'replay_k', 'discriminator', 'gail_weight',
        'sample_g_first', 'zero_action_p', 'dis_bound', 'two_rs',
        'with_termination',
    )
    for option in optional_names:
        if option not in params:
            continue
        # Move the option out of params, keeping an underscore-prefixed copy.
        value = params.pop(option)
        her_params[option] = value
        params['_' + option] = value

    # The nearby-action options are forwarded only when the penalty is configured.
    nearby_kwargs = {}
    if 'nearby_action_penalty' in params:
        nearby_names = (
            'nearby_action_penalty', 'nearby_p', 'perturb_scale',
            'cells_apart', 'perturb_to_feasible',
        )
        nearby_kwargs = {key: params[key] for key in nearby_names}

    return make_sample_her_transitions(**her_params, env=env, **nearby_kwargs)
Exemplo n.º 2
0
def configure_her(params):
    """
    Set up hindsight experience replay (HER).

    The environment built from ``params['make_env']`` is reset once and its
    ``compute_reward`` method is used as the reward function. The
    ``'replay_strategy'`` and ``'replay_k'`` entries are moved to
    ``'_replay_strategy'`` and ``'_replay_k'`` inside ``params``.

    :param params: (dict) input parameters
    :return: (function (dict, int): dict) returns a HER update function for replay buffer batch
    """
    env = cached_make_env(params['make_env'])
    env.reset()

    def reward_fun(achieved_goal, goal, info):  # vectorized
        """Forward the batch to the environment's ``compute_reward``."""
        return env.compute_reward(achieved_goal=achieved_goal, desired_goal=goal, info=info)

    # Prepare configuration for HER.
    her_params = {'reward_fun': reward_fun}
    for option in ('replay_strategy', 'replay_k'):
        # Move the option out of params, keeping an underscore-prefixed copy.
        value = params.pop(option)
        her_params[option] = value
        params['_' + option] = value

    return make_sample_her_transitions(**her_params)
def configure_her(params):
    """Build the HER transition sampler selected by ``params['prioritization']``.

    The environment built from ``params['make_env']`` is reset once and its
    ``compute_reward`` method is used as the reward function. The
    ``'replay_strategy'`` and ``'replay_k'`` entries are moved to
    ``'_replay_strategy'`` and ``'_replay_k'`` inside ``params``.

    :param params: (dict) configuration; must contain ``'make_env'``,
        ``'replay_strategy'``, ``'replay_k'`` and ``'prioritization'``
    :return: the object returned by the chosen ``make_sample_her_transitions*``
        factory
    """
    env = cached_make_env(params['make_env'])
    env.reset()

    def reward_fun(ag_2, g, info):  # vectorized
        """Forward the batch to the environment's ``compute_reward``."""
        return env.compute_reward(achieved_goal=ag_2, desired_goal=g, info=info)

    # Prepare configuration for HER.
    her_params = {'reward_fun': reward_fun}
    for option in ('replay_strategy', 'replay_k'):
        # Move the option out of params, keeping an underscore-prefixed copy.
        value = params.pop(option)
        her_params[option] = value
        params['_' + option] = value

    # 'energy' and 'tderror' pick dedicated samplers; anything else falls back
    # to the plain HER sampler.
    mode = params['prioritization']
    if mode == 'energy':
        make_sampler = make_sample_her_transitions_energy
    elif mode == 'tderror':
        make_sampler = make_sample_her_transitions_prioritized_replay
    else:
        make_sampler = make_sample_her_transitions

    return make_sampler(**her_params)
Exemplo n.º 4
0
def configure_her(params):
    """Build the HER transition sampler from ``params``.

    The environment built from ``params['make_env']`` is reset once and its
    ``compute_reward`` method is used as the reward function. The
    ``'replay_strategy'`` and ``'replay_k'`` entries are moved to
    ``'_replay_strategy'`` and ``'_replay_k'`` inside ``params``.

    :param params: (dict) configuration; must contain ``'make_env'``,
        ``'replay_strategy'`` and ``'replay_k'``
    :return: the object returned by ``make_sample_her_transitions``
    """
    env = cached_make_env(params['make_env'])
    env.reset()

    def reward_fun(ag_2, g, info):  # vectorized
        """Forward the batch to the environment's ``compute_reward``."""
        return env.compute_reward(
            achieved_goal=ag_2,
            desired_goal=g,
            info=info,
        )

    # Prepare configuration for HER.
    her_params = {'reward_fun': reward_fun}
    for option in ('replay_strategy', 'replay_k'):
        # Move the option out of params, keeping an underscore-prefixed copy.
        value = params.pop(option)
        her_params[option] = value
        params['_' + option] = value

    return make_sample_her_transitions(**her_params)
Exemplo n.º 5
0
def configure_her(params):
    """Build the HER transition sampler with a threshold-aware reward.

    The environment built from ``params['make_env']`` is reset once. The
    reward function forwards an extra ``th`` argument to the environment's
    ``compute_reward`` as ``threshold``. The ``'replay_strategy'`` and
    ``'replay_k'`` entries are moved to ``'_replay_strategy'`` and
    ``'_replay_k'`` inside ``params``.

    :param params: (dict) configuration; must contain ``'make_env'``,
        ``'replay_strategy'`` and ``'replay_k'``
    :return: the object returned by ``make_sample_her_transitions``
    """
    env = cached_make_env(params['make_env'])
    env.reset()

    def reward_fun(ag_2, g, info, th):  # vectorized
        """Forward the batch and threshold to the environment's ``compute_reward``."""
        return env.compute_reward(achieved_goal=ag_2, desired_goal=g, info=info, threshold=th)

    # Prepare configuration for HER.
    her_params = {'reward_fun': reward_fun}
    for option in ('replay_strategy', 'replay_k'):
        # Move the option out of params, keeping an underscore-prefixed copy.
        value = params.pop(option)
        her_params[option] = value
        params['_' + option] = value

    return make_sample_her_transitions(**her_params)
Exemplo n.º 6
0
def configure_her(params):
    """Build the HER transition sampler, including the MI/ET schedule options.

    The environment built from ``params['make_env']`` is reset once and its
    ``compute_reward`` method is used as the reward function.

    ``'replay_strategy'``, ``'replay_k'``, ``'mi_w_schedule'`` and
    ``'et_w_schedule'`` are moved to their underscore-prefixed keys inside
    ``params``; ``'mi_prioritization'`` is forwarded too but is left in
    ``params`` under its original key.

    :param params: (dict) configuration; must contain ``'make_env'`` and all
        five options named above
    :return: the object returned by ``make_sample_her_transitions``
    """
    env = cached_make_env(params['make_env'])
    env.reset()

    def reward_fun(ag_2, g, info):  # vectorized
        """Forward the batch to the environment's ``compute_reward``."""
        return env.compute_reward(achieved_goal=ag_2, desired_goal=g, info=info)

    her_params = {'reward_fun': reward_fun}

    moved_options = ('replay_strategy', 'replay_k', 'mi_w_schedule', 'et_w_schedule')
    for option in moved_options:
        # Move the option out of params, keeping an underscore-prefixed copy.
        value = params.pop(option)
        her_params[option] = value
        params['_' + option] = value

    # Copied only: this key stays in params unchanged.
    her_params['mi_prioritization'] = params['mi_prioritization']

    return make_sample_her_transitions(**her_params)