コード例 #1
0
ファイル: test_plotting.py プロジェクト: fdamken/SimuRLacra
def test_rollout_based(env: SimEnv, policy: Policy):
    """
    Sample one rollout and pass it to the plotting functions that fit the given policy.

    A `LinearPolicy` only gets its features plotted, a `PotentialBasedPolicy` only its potentials. For every
    other policy, the observations, actions, rewards, and the recorded time deltas are plotted.

    :param env: environment to sample the rollout from
    :param policy: policy to sample the rollout with, its type selects the plots
    """
    # The time deltas are recorded since they are needed by `draw_dts` in the generic case
    sampled = rollout(env, policy, record_dts=True)

    if isinstance(policy, LinearPolicy):
        plot_features(sampled, policy)
        return

    if isinstance(policy, PotentialBasedPolicy):
        plot_potentials(sampled)
        return

    # Generic policy: run all plots that only depend on the rollout data
    plot_observations_actions_rewards(sampled)
    plot_observations(sampled)
    plot_actions(sampled, env)
    plot_rewards(sampled)
    draw_dts(
        sampled.dts_policy,
        sampled.dts_step,
        sampled.dts_remainder,
        y_top_lim=5,
    )
コード例 #2
0
    def policy_fcn(t: float) -> list:
        """
        Return the same constant 6-element list for every time `t`, i.e. the argument is ignored.

        :param t: time, unused
        :return: list of 6 constant values
        """
        return [0.7, 1, 0, 0.1, 0.5, 0.5]

    policy = TimePolicy(env.spec, policy_fcn, dt)

    # Simulate and plot potentials
    print(env.obs_space.labels)
    return rollout(env,
                   policy,
                   render_mode=RenderMode(video=True),
                   stop_on_done=False)


if __name__ == '__main__':
    # Select the control variant to run: 'ik' or 'activation'
    setup_type = 'ik'

    # Hyper-parameters that are shared by both variants
    common_hparam = {
        'dt': 0.01,
        'max_steps': 1200,
        'max_dist_force': None,
        'physics_engine': 'Bullet',  # Bullet or Vortex
        'graph_file_name': 'gPlanarInsert6Link.xml',  # gPlanarInsert6Link.xml or gPlanarInsert5Link.xml
    }

    # Only the activation variant gets its potentials plotted afterwards
    if setup_type == 'activation':
        ro = task_activation_variant(**common_hparam)
        plot_potentials(ro)
    elif setup_type == 'ik':
        ro = ik_control_variant(**common_hparam)
コード例 #3
0
def adn_variant(dt,
                max_steps,
                max_dist_force,
                physics_engine,
                normalize_obs=True,
                obsnorm_cpp=True):
    """
    Roll out a randomly initialized `ADNPolicy` in a `Planar3LinkTASim` environment and plot its potentials.

    :param dt: forwarded as `dt` to the environment as well as to the policy
    :param max_steps: forwarded as `max_steps` to the environment
    :param max_dist_force: forwarded as `max_dist_force` to the environment
    :param physics_engine: forwarded as `physicsEngine` to the environment
    :param normalize_obs: if `True`, normalize the observations using the explicit bounds defined in this function
    :param obsnorm_cpp: if `True`, request the normalization from the environment via keyword arguments,
                        else wrap the environment with an `ObsNormWrapper`
    :return: rollout recorded with the random policy
    """
    pyrado.set_seed(1001)

    # Explicit lower and upper bounds used for the observation normalization
    obs_lower = {
        'EffectorLoadCell_Fx': -100.,
        'EffectorLoadCell_Fz': -100.,
        'Effector_Xd': -1,
        'Effector_Zd': -1,
        'GD_DS0d': -1,
        'GD_DS1d': -1,
        'GD_DS2d': -1,
    }
    obs_upper = {
        'GD_DS0': 3.,
        'GD_DS1': 3,
        'GD_DS2': 3,
        'EffectorLoadCell_Fx': 100.,
        'EffectorLoadCell_Fz': 100.,
        'Effector_Xd': .5,
        'Effector_Zd': .5,
        'GD_DS0d': .5,
        'GD_DS1d': .5,
        'GD_DS2d': .5,
        'PredCollCost_h50': 1000.
    }

    # The normalization is either requested from the environment itself ...
    norm_kwargs = (
        {
            'normalizeObservations': True,
            'obsNormOverrideLower': obs_lower,
            'obsNormOverrideUpper': obs_upper,
        }
        if normalize_obs and obsnorm_cpp
        else {}
    )

    env = Planar3LinkTASim(
        physicsEngine=physics_engine,
        dt=dt,
        max_steps=max_steps,
        max_dist_force=max_dist_force,
        collisionAvoidanceIK=True,
        taskCombinationMethod='sum',
        **norm_kwargs
    )

    # ... or done by a wrapper around the environment
    if normalize_obs and not obsnorm_cpp:
        env = ObsNormWrapper(env, explicit_lb=obs_lower, explicit_ub=obs_upper)

    # Create the policy with its initial (not trained) parameters
    policy = ADNPolicy(
        spec=env.spec,
        dt=dt,
        tau_init=0.2,
        activation_nonlin=to.sigmoid,
        potentials_dyn_fcn=pd_cubic,
    )
    print_cbt('Running ADNPolicy with random initialization', 'c', bright=True)

    # Simulate with video rendering, then plot the potentials of the recorded rollout
    recorded = rollout(env, policy, render_mode=RenderMode(video=True), stop_on_done=True)
    plot_potentials(recorded)
    return recorded
コード例 #4
0
ファイル: rollout.py プロジェクト: fdamken/SimuRLacra
def after_rollout_query(
    env: Env, policy: Policy, rollout: StepSequence
) -> Tuple[bool, Optional[np.ndarray], Optional[dict]]:
    """
    Ask the user what to do after a rollout has been animated.

    Options that only display something (information or plots), as well as unknown or unparsable inputs, open the
    query again. Every code path therefore returns the same 3-tuple.

    :param env: environment used for the rollout
    :param policy: policy used for the rollout
    :param rollout: collected data from the rollout
    :return: done flag, initial state, and domain parameters
    """
    # First entry contains hotkey, second the info text
    options = [
        ["C", "continue simulation (with domain randomization)"],
        ["N", "set domain parameters to nominal values, and continue"],
        ["F", "fix the initial state"],
        ["I", "print information about environment (including randomizer), and policy"],
        ["S", "set a domain parameter explicitly"],
        ["P", "plot all observations, actions, and rewards"],
        ["PS [indices]", "plot all states, or selected ones by passing separated integers"],
        ["PO [indices]", "plot all observations, or selected ones by passing separated integers"],
        ["PA", "plot actions"],
        ["PR", "plot rewards"],
        ["PF", "plot features (for linear policy)"],
        ["PPOT", "plot potentials, stimuli, and actions (for potential-based policies)"],
        ["PDT", "plot time deltas (profiling of a real system)"],
        ["E", "exit"],
    ]

    # Ask for user input
    ans = input(tabulate(options, tablefmt="simple") + "\n").lower()

    if ans == "c" or ans == "":
        # We don't have to do anything here since the env will be reset at the beginning of the next rollout
        return False, None, None

    elif ans == "f":
        innermost = inner_env(env)
        if isinstance(innermost, RealEnv) or not isinstance(innermost, SimEnv):
            # Previously, an env that is neither real nor simulated fell through and returned None
            print_cbt("The initial state can only be fixed for simulation environments.", "r")
            return after_rollout_query(env, policy, rollout)

        # Get the user input
        usr_inp = input(
            f"Enter the {env.obs_space.flat_dim}-dim initial state "
            f"(format: each dim separated by a whitespace):\n"
        )
        try:
            state = np.array(list(map(float, usr_inp.split())))
        except ValueError:
            # Non-numeric input must not crash the interactive session, ask again instead
            print_cbt(f"Could not parse {usr_inp} into a state.", "r")
            return after_rollout_query(env, policy, rollout)

        if state.shape != env.obs_space.shape:
            print_cbt(f"Expected a state of shape {env.obs_space.shape}, but got shape {state.shape}.", "r")
            return after_rollout_query(env, policy, rollout)
        return False, state, {}

    elif ans == "n":
        # Get nominal domain parameters
        if isinstance(inner_env(env), SimEnv):
            dp_nom = inner_env(env).get_nominal_domain_param()
            if typed_env(env, ActDelayWrapper) is not None:
                # There is an ActDelayWrapper in the env chain
                dp_nom["act_delay"] = 0
        else:
            dp_nom = None
        return False, None, dp_nom

    elif ans == "i":
        # Print the information and return to the query
        print(env)
        if hasattr(env, "randomizer"):
            print(env.randomizer)
        print(policy)
        return after_rollout_query(env, policy, rollout)

    elif ans == "p":
        plot_observations_actions_rewards(rollout)
        plt.show()
        return after_rollout_query(env, policy, rollout)

    elif ans == "pa":
        plot_actions(rollout, env)
        plt.show()
        return after_rollout_query(env, policy, rollout)

    elif ans == "po":
        plot_observations(rollout)
        plt.show()
        return after_rollout_query(env, policy, rollout)

    elif "po" in ans and any(char.isdigit() for char in ans):
        # NOTE(review): this is a substring match, i.e. an input like "ppot 1" also ends up here
        idcs = [int(s) for s in ans.split() if s.isdigit()]
        plot_observations(rollout, idcs_sel=idcs)
        plt.show()
        return after_rollout_query(env, policy, rollout)

    elif ans == "ps":
        plot_states(rollout)
        plt.show()
        return after_rollout_query(env, policy, rollout)

    elif "ps" in ans and any(char.isdigit() for char in ans):
        idcs = [int(s) for s in ans.split() if s.isdigit()]
        plot_states(rollout, idcs_sel=idcs)
        plt.show()
        return after_rollout_query(env, policy, rollout)

    elif ans == "pf":
        plot_features(rollout, policy)
        plt.show()
        return after_rollout_query(env, policy, rollout)

    elif ans == "pp":
        draw_policy_params(policy, env.spec, annotate=False)
        plt.show()
        return after_rollout_query(env, policy, rollout)

    elif ans == "pr":
        plot_rewards(rollout)
        plt.show()
        return after_rollout_query(env, policy, rollout)

    elif ans == "pdt":
        draw_dts(rollout.dts_policy, rollout.dts_step, rollout.dts_remainder)
        plt.show()
        # Return the result as is, wrapping it into another tuple breaks the unpacking at the caller
        return after_rollout_query(env, policy, rollout)

    elif ans == "ppot":
        plot_potentials(rollout)
        plt.show()
        return after_rollout_query(env, policy, rollout)

    elif ans == "s":
        if isinstance(env, SimEnv):
            dp = env.get_nominal_domain_param()
            # Wrap every value into a list of one element to make it iterable for tabulate. A new dict is created
            # for that, such that the dict returned by the environment stays untouched.
            dp_table = {k: [v] for k, v in dp.items()}
            print("These are the nominal domain parameters:")
            print(tabulate(dp_table, headers="keys", tablefmt="simple"))

        # Get the user input
        strs = input("Enter one new domain parameter\n(format: key whitespace value):\n")
        try:
            # Every line must consist of exactly one key and one value which can be cast to float
            param = {key: float(val) for key, val in (line.split() for line in strs.splitlines())}
        except ValueError:
            print_cbt(f"Could not parse {strs} into a dict.", "r")
            # The result of the new query must be returned, otherwise the caller receives None
            return after_rollout_query(env, policy, rollout)
        return False, None, param

    elif ans == "e":
        env.close()
        return True, None, {}  # breaks the outer while loop

    else:
        return after_rollout_query(env, policy, rollout)  # recursion
コード例 #5
0
ファイル: sb_p3l.py プロジェクト: fdamken/SimuRLacra
def create_adn_setup(dt,
                     max_steps,
                     max_dist_force,
                     physics_engine,
                     normalize_obs=True,
                     obsnorm_cpp=True):
    """
    Create a `Planar3LinkTASim` environment with a randomly initialized `ADNPolicy`, run one rollout, and plot the
    potentials.

    :param dt: forwarded as `dt` to the environment
    :param max_steps: forwarded as `max_steps` to the environment
    :param max_dist_force: forwarded as `max_dist_force` to the environment
    :param physics_engine: forwarded as `physicsEngine` to the environment
    :param normalize_obs: if `True`, normalize the observations using the explicit bounds defined in this function
    :param obsnorm_cpp: if `True`, request the normalization from the environment via keyword arguments,
                        else wrap the environment with an `ObsNormWrapper`
    :return: rollout recorded with the random policy
    """
    pyrado.set_seed(0)

    # Explicit lower and upper bounds used for the observation normalization
    bounds_lo = {
        "EffectorLoadCell_Fx": -100.0,
        "EffectorLoadCell_Fz": -100.0,
        "Effector_Xd": -1,
        "Effector_Zd": -1,
        "GD_DS0d": -1,
        "GD_DS1d": -1,
        "GD_DS2d": -1,
    }
    bounds_up = {
        "GD_DS0": 3.0,
        "GD_DS1": 3,
        "GD_DS2": 3,
        "EffectorLoadCell_Fx": 100.0,
        "EffectorLoadCell_Fz": 100.0,
        "Effector_Xd": 0.5,
        "Effector_Zd": 0.5,
        "GD_DS0d": 0.5,
        "GD_DS1d": 0.5,
        "GD_DS2d": 0.5,
        "PredCollCost_h50": 1000.0,
    }

    # Keyword arguments which make the environment normalize the observations on its own
    if normalize_obs and obsnorm_cpp:
        norm_kwargs = {
            "normalizeObservations": True,
            "obsNormOverrideLower": bounds_lo,
            "obsNormOverrideUpper": bounds_up,
        }
    else:
        norm_kwargs = {}

    env = Planar3LinkTASim(
        physicsEngine=physics_engine,
        dt=dt,
        max_steps=max_steps,
        max_dist_force=max_dist_force,
        positionTasks=True,
        collisionAvoidanceIK=True,
        taskCombinationMethod="sum",
        observeTaskSpaceDiscrepancy=True,
        **norm_kwargs,
    )

    # Alternatively, the normalization is done by a wrapper around the environment
    if normalize_obs and not obsnorm_cpp:
        env = ObsNormWrapper(env, explicit_lb=bounds_lo, explicit_ub=bounds_up)

    # Create the policy with its initial (not trained) parameters
    policy = ADNPolicy(
        spec=env.spec,
        tau_init=10.0,
        activation_nonlin=to.sigmoid,
        potentials_dyn_fcn=pd_cubic,
    )
    print_cbt("Running ADNPolicy with random initialization", "c", bright=True)

    # Simulate with video rendering, then plot the potentials of the recorded rollout
    recorded = rollout(env, policy, render_mode=RenderMode(video=True), stop_on_done=True)
    plot_potentials(recorded)
    return recorded
コード例 #6
0
def after_rollout_query(env: Env, policy: Policy, rollout: StepSequence) -> tuple:
    """
    Ask the user what to do after a rollout has been animated.

    Options that only plot something, as well as unknown or unparsable inputs, open the query again. Every code
    path therefore returns the same 3-tuple.

    :param env: environment used for the rollout
    :param policy: policy used for the rollout
    :param rollout: collected data from the rollout
    :return: done flag, initial state, and domain parameters
    """
    # First entry contains hotkey, second the info text
    options = [
        ['C', 'continue simulation (with domain randomization)'],
        ['N', 'set domain parameters to nominal values and continue'],
        ['F', 'fix the initial state'],
        ['S', 'set a domain parameter explicitly'],
        ['P', 'plot all observations, actions, and rewards'],
        ['PA', 'plot actions'],
        ['PR', 'plot rewards'],
        ['PO', 'plot all observations'],
        ['PO idcs', 'plot selected observations (integers separated by whitespaces)'],
        ['PF', 'plot features (for linear policy)'],
        ['PP', 'plot policy parameters (not suggested for many parameters)'],
        ['PDT', 'plot time deltas (profiling of a real system)'],
        ['PPOT', 'plot potentials, stimuli, and actions'],
        ['E', 'exit']
    ]

    # Ask for user input
    ans = input(tabulate(options, tablefmt='simple') + '\n').lower()

    if ans == 'c' or ans == '':
        # We don't have to do anything here since the env will be reset at the beginning of the next rollout
        return False, None, None

    elif ans == 'n':
        # Get nominal domain parameters
        if isinstance(inner_env(env), SimEnv):
            dp_nom = inner_env(env).get_nominal_domain_param()
            if typed_env(env, ActDelayWrapper) is not None:
                # There is an ActDelayWrapper in the env chain
                dp_nom['act_delay'] = 0
        else:
            dp_nom = None
        return False, None, dp_nom

    elif ans == 'p':
        plot_observations_actions_rewards(rollout)
        return after_rollout_query(env, policy, rollout)

    elif ans == 'pa':
        plot_actions(rollout, env)
        return after_rollout_query(env, policy, rollout)

    elif ans == 'po':
        plot_observations(rollout)
        return after_rollout_query(env, policy, rollout)

    elif 'po' in ans and any(char.isdigit() for char in ans):
        # NOTE(review): this is a substring match, i.e. an input like 'ppot 1' also ends up here
        idcs = [int(s) for s in ans.split() if s.isdigit()]
        plot_observations(rollout, idcs_sel=idcs)
        return after_rollout_query(env, policy, rollout)

    elif ans == 'pf':
        plot_features(rollout, policy)
        return after_rollout_query(env, policy, rollout)

    elif ans == 'pp':
        from matplotlib import pyplot as plt
        render_policy_params(policy, env.spec, annotate=False)
        plt.show()
        return after_rollout_query(env, policy, rollout)

    elif ans == 'pr':
        plot_rewards(rollout)
        return after_rollout_query(env, policy, rollout)

    elif ans == 'pdt':
        plot_dts(rollout.dts_policy, rollout.dts_step, rollout.dts_remainder)
        # Return the result as is, a trailing comma here would wrap it into another tuple
        return after_rollout_query(env, policy, rollout)

    elif ans == 'ppot':
        plot_potentials(rollout)
        return after_rollout_query(env, policy, rollout)

    elif ans == 's':
        if isinstance(env, SimEnv):
            dp = env.get_nominal_domain_param()
            # Wrap every value into a list of one element to make it iterable for tabulate. A new dict is created
            # for that, such that the dict returned by the environment stays untouched.
            dp_table = {k: [v] for k, v in dp.items()}
            print('These are the nominal domain parameters:')
            print(tabulate(dp_table, headers="keys", tablefmt='simple'))

        # Get the user input
        strs = input('Enter one new domain parameter\n(format: key whitespace value):\n')
        try:
            # Every line must consist of exactly one key and one value which can be cast to float
            param = {key: float(val) for key, val in (line.split() for line in strs.splitlines())}
        except ValueError:
            print_cbt(f'Could not parse {strs} into a dict.', 'r')
            # The result of the new query must be returned, otherwise the caller receives None
            return after_rollout_query(env, policy, rollout)
        return False, None, param

    elif ans == 'f':
        innermost = inner_env(env)
        if isinstance(innermost, RealEnv) or not isinstance(innermost, SimEnv):
            # Previously, an env that is neither real nor simulated fell through and returned None
            print_cbt('The initial state can only be fixed for simulation environments.', 'r')
            return after_rollout_query(env, policy, rollout)

        # Get the user input (the separating space after "initial state" was missing before)
        usr_inp = input(f'Enter the {env.obs_space.flat_dim}-dim initial state '
                        f'(format: each dim separated by a whitespace):\n')
        try:
            state = np.array(list(map(float, usr_inp.split())))
        except ValueError:
            # Non-numeric input must not crash the interactive session, ask again instead
            print_cbt(f'Could not parse {usr_inp} into a state.', 'r')
            return after_rollout_query(env, policy, rollout)

        if state.shape != env.obs_space.shape:
            print_cbt(f'Expected a state of shape {env.obs_space.shape}, but got shape {state.shape}.', 'r')
            return after_rollout_query(env, policy, rollout)
        return False, state, {}

    elif ans == 'e':
        env.close()
        return True, None, {}  # breaks the outer while loop

    else:
        return after_rollout_query(env, policy, rollout)  # recursion