Code Example #1
def rollout(self, env, policy, max_path_length):
    # Sample a goal for this evaluation rollout; here the `discount` keyword
    # carries the max_tau horizon and is cycled like the tau counter in the
    # other examples.
    goal = env.sample_goal_for_rollout()
    return multitask_rollout(
        env,
        agent=policy,
        goal=goal,
        discount=self.max_tau,
        max_path_length=max_path_length,
        decrement_discount=self.cycle_taus_for_rollout,
        cycle_tau=self.cycle_taus_for_rollout,
    )
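
All of these examples go through `multitask_rollout` from `railrl.state_distance.rollout_util` (the import appears in Code Example #7). Example #1 uses an older-looking set of keyword names (`discount`, `decrement_discount`) while the later examples use `init_tau` / `decrement_tau`. The stub below is not the railrl definition; it is a hedged sketch that only collects the keyword arguments and path-dictionary keys the call sites in this listing actually rely on, using the later examples' naming, and its defaults are assumptions.

# Sketch only -- reconstructed from the call sites in this listing, not from the
# railrl source. Keyword names and returned dict keys are the ones the examples
# use; the defaults are assumptions.
def multitask_rollout_sketch(
        env,
        agent,
        goal=None,
        init_tau=0,                       # Example #1 passes this as `discount`
        max_path_length=100,
        animated=False,
        cycle_tau=False,                  # restart the tau countdown when it reaches 0
        decrement_tau=False,              # count tau down by one per step
        env_samples_goal_on_reset=False,  # used in Code Example #8
        get_action_kwargs=None,
):
    """Roll out `agent` toward `goal` while tracking the tau counter.

    The examples index the returned path dict with at least these keys:
    'terminals', 'num_steps_left', 'next_observations', 'agent_infos'.
    """
    raise NotImplementedError("signature sketch only; see railrl for the real rollout")
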
Code Example #2
def plot_performance(policy, env, nrolls):
    print("max_tau, distance")
    # fixed_goals = [-40, -30, 30, 40]
    fixed_goals = [-5, -3, 3, 5]
    taus = np.arange(10) * 10
    for row, fix_tau in enumerate([True, False]):
        for col, horizon_fixed in enumerate([True, False]):
            plot_num = row + 2 * col + 1
            plt.subplot(2, 2, plot_num)
            for fixed_goal in fixed_goals:
                distances = []
                for max_tau in taus:
                    paths = []
                    for _ in range(nrolls):
                        goal = env.sample_goal_for_rollout()
                        goal[0] = fixed_goal
                        path = multitask_rollout(
                            env,
                            policy,
                            goal,
                            init_tau=max_tau,
                            max_path_length=100 if horizon_fixed else max_tau + 1,
                            animated=False,
                            cycle_tau=True,
                            decrement_tau=not fix_tau,
                        )
                        paths.append(path)
                    env.log_diagnostics(paths)
                    for key, value in get_generic_path_information(
                            paths).items():
                        logger.record_tabular(key, value)
                    distance = float(
                        dict(logger._tabular)['Final Distance to goal Mean'])
                    distances.append(distance)

                plt.plot(taus, distances)
                print("line done")
            plt.legend([str(goal) for goal in fixed_goals])
            if fix_tau:
                plt.xlabel("Tau (Horizon-1)")
            else:
                plt.xlabel("Initial tau (=Horizon-1)")
            plt.ylabel("Final distance to goal")
            plt.title("Fix Tau = {}, Horizon Fixed to 100  = {}".format(
                fix_tau,
                horizon_fixed,
            ))
    # Save before show(); calling savefig after show() can write out an empty figure.
    plt.savefig('results/iclr2018/cheetah-sweep-tau-eval-5-3.jpg')
    plt.show()
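
`plot_performance` expects a trained `policy` and `env` to already be in memory, along with the module-level `logger`, `multitask_rollout`, and `get_generic_path_information` imports it uses. In this codebase those objects typically come out of a joblib snapshot, as Code Example #7 shows. A hypothetical invocation along those lines (the snapshot path is made up; the dictionary keys mirror Example #7):

import joblib

# Hypothetical usage sketch -- key names follow Code Example #7, the path is invented.
data = joblib.load('params.pkl')
plot_performance(data['policy'], data['env'], nrolls=5)
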
Code Example #3
def test_decrement_tau(self):
    env = StubMultitaskEnv()
    policy = StubUniversalPolicy()
    goal = None
    tau = 10
    path = multitask_rollout(
        env,
        policy,
        goal,
        tau,
        max_path_length=tau,
        animated=False,
        decrement_tau=True,
    )
    self.assertTrue(np.all(path['terminals'] == False))
    self.assertTrue(len(path['terminals']) == tau)
Code Example #4
def test_multitask_rollout_length(self):
    env = StubMultitaskEnv()
    policy = StubUniversalPolicy()
    goal = None
    discount = 1
    path = multitask_rollout(
        env,
        policy,
        goal,
        discount,
        max_path_length=100,
        animated=False,
        decrement_tau=False,
    )
    self.assertTrue(np.all(path['terminals'] == False))
    self.assertTrue(len(path['terminals']) == 100)
Code Example #5
def test_decrement_tau(self):
    env = StubMultitaskEnv()
    policy = StubUniversalPolicy()
    goal = None
    tau = 5
    path = multitask_rollout(
        env,
        policy,
        goal,
        tau,
        max_path_length=10,
        animated=False,
        decrement_tau=True,
        cycle_tau=False,
    )
    self.assertEqual(
        list(path['num_steps_left']),
        [5, 4, 3, 2, 1, 0, 0, 0, 0, 0]
    )
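
The expected `num_steps_left` sequence in this test follows directly from the flag combination: with `decrement_tau=True` and `cycle_tau=False`, tau counts down from its initial value once per step and then stays at zero for the rest of the path. The snippet below is a minimal sketch of that counting rule; it reproduces what the assertion encodes, not the railrl implementation.

def expected_num_steps_left(init_tau, max_path_length, decrement_tau, cycle_tau):
    # Sketch of the tau schedule the tests above assert on -- not railrl code.
    taus, tau = [], init_tau
    for _ in range(max_path_length):
        taus.append(tau)
        if decrement_tau:
            tau -= 1
            if tau < 0:
                # cycle_tau restarts the countdown at init_tau; otherwise clamp at 0
                tau = init_tau if cycle_tau else 0
    return taus

assert expected_num_steps_left(5, 10, True, False) == [5, 4, 3, 2, 1, 0, 0, 0, 0, 0]
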
Code Example #6
         #     original_policy,
         #     max_path_length=args.H,
         #     animated=not args.hide,
         # )
         # goal = np.array([1.4952445864440109, 0.058365245652776926,
         #                  1.3854542196239863, -0.64643021271356582, 0.25729402753586905, -1.0559116816553138, -1.2942449012062724, 0.84327192781565719, -0.18665817808605106, 0.28887389778176836, -4.1567137920511996, -0.25677653709657877, 1.2789295463658288, 0.47291580030348057, 0.34130661157042974, 0.13003414588968379, -0.009319281912785882])
         # goal = np.array([1.6958471372903317, 0.2122816058111654,
         #                  0.29760944600589051, -0.016908392188567031, -0.58501650189613841, -0.018928029822669078, -1.2091424324357098, 0.16575094693524303, 0.32991058173255483, 2.8226738796936663, -0.57674228567507868, 1.5591211986667852, 0.53321884401877584, -3.8082528691546091, -0.11086735631355096, 0.29765427337121497, -0.16364599916575717])
         goal = env.sample_goal_for_rollout()
         goal[7:14] = 0
         path = multitask_rollout(
             env,
             original_policy,
             # env.multitask_goal,
             goal,
             init_tau=10,
             max_path_length=args.H,
             animated=not args.hide,
             cycle_tau=True,
             decrement_tau=False,
         )
         if hasattr(env, "log_diagnostics"):
             env.log_diagnostics([path])
         logger.dump_tabular()
 else:
     for weight in [1]:
         for num_simulated_paths in [args.npath]:
             print("")
             print("weight", weight)
             print("num_simulated_paths", num_simulated_paths)
             policy = CollocationMpcController(
Code Example #7
from railrl.state_distance.rollout_util import multitask_rollout

if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument('file', type=str,
                        help='path to the snapshot file')
    parser.add_argument('--pause', action='store_true')
    args = parser.parse_args()
    if args.pause:
        import ipdb; ipdb.set_trace()

    data = joblib.load(args.file)
    env = data['env']
    qf = data['qf']
    policy = data['policy']
    tdm_policy = data['trained_policy']
    random_policy = RandomPolicy(env.action_space)
    vf = data['vf']
    path = multitask_rollout(
        env,
        random_policy,
        init_tau=0,
        max_path_length=100,
        animated=True,
    )
    goal = env.sample_goal_for_rollout()

    import ipdb; ipdb.set_trace()
    agent_infos = path['agent_infos']
Code Example #8
        # some environments need to be reconfigured for visualization
        env.enable_render()
    if args.mode:
        env.mode(args.mode)

    while True:
        paths = []
        for _ in range(args.nrolls):
            if args.silent:
                goal = None
            else:
                goal = env.sample_goal_for_rollout()
            path = multitask_rollout(
                env,
                policy,
                init_tau=max_tau,
                goal=goal,
                max_path_length=args.H,
                # animated=not args.hide,
                cycle_tau=args.cycle or not args.ndc,
                decrement_tau=args.dt or not args.ndc,
                env_samples_goal_on_reset=args.silent,
                # get_action_kwargs={'deterministic': True},
            )
            print("last state", path['next_observations'][-1][21:24])
            paths.append(path)
        env.log_diagnostics(paths)
        for key, value in get_generic_path_information(paths).items():
            logger.record_tabular(key, value)
        logger.dump_tabular()