Example #1
import os
import random

import numpy as np
import torch
from mpi4py import MPI

# make_env, get_controller, logger, get_env_params and dqn_rts_agent
# are helpers from the surrounding project.


def launch(args):
    env = make_env(args.env_name, env_id=args.env_id,
                   discrete=True, reward_type=args.reward_type)
    # Set random seeds for reproducibility, offset by MPI rank so that
    # each worker process is seeded differently
    rank_seed = args.seed + MPI.COMM_WORLD.Get_rank()
    env.seed(rank_seed)
    random.seed(rank_seed)
    np.random.seed(rank_seed)
    torch.manual_seed(rank_seed)
    if args.deterministic:
        env.make_deterministic()
    if args.debug:
        logger.set_level(logger.DEBUG)
    controller = get_controller(args.env_name,
                                env_id=args.env_id,
                                discrete=True,
                                num_expansions=args.num_expansions,
                                reward_type=args.reward_type)

    # Configure the logger (only on the rank-0 process)
    if MPI.COMM_WORLD.Get_rank() == 0 and args.log_dir:
        logger.configure(
            dir=os.path.join('logs', 'rts', args.log_dir),
            format_strs=['tensorboard', 'log', 'csv', 'json', 'stdout'])
    args.log_dir = logger.get_dir()
    assert args.log_dir is not None
    os.makedirs(args.log_dir, exist_ok=True)

    # Get the environment parameters
    env_params = get_env_params(env)

    # Create the RTS agent and start training
    rts_trainer = dqn_rts_agent(args, env, env_params, controller)
    rts_trainer.learn()
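
The args.seed + MPI.COMM_WORLD.Get_rank() idiom above recurs throughout these examples. A minimal helper sketch that centralizes it (the seed_everything name is hypothetical, not part of the project):

import random

import numpy as np
import torch
from mpi4py import MPI


def seed_everything(base_seed, env=None):
    # Hypothetical helper: derive a distinct seed per MPI worker from a
    # shared base seed, so parallel rollouts are decorrelated
    rank_seed = base_seed + MPI.COMM_WORLD.Get_rank()
    random.seed(rank_seed)
    np.random.seed(rank_seed)
    torch.manual_seed(rank_seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(rank_seed)
    if env is not None:
        env.seed(rank_seed)
    return rank_seed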
Example #2
# Imports as in Example #1.
def launch(args):
    assert not args.env_name.startswith(
        'Residual'), 'Residual envs not allowed'
    # Create the environment and controller
    env = make_env(args.env_name)
    controller = get_controller(args.env_name)
    # Set random seeds for reproducibility, offset by MPI rank
    rank_seed = args.seed + MPI.COMM_WORLD.Get_rank()
    env.seed(rank_seed)
    random.seed(rank_seed)
    np.random.seed(rank_seed)
    torch.manual_seed(rank_seed)
    if args.cuda:
        torch.cuda.manual_seed(rank_seed)
    # Configure the logger on the rank-0 process, falling back to a
    # directory named after the environment when no log dir is given
    if MPI.COMM_WORLD.Get_rank() == 0:
        log_subdir = args.log_dir if args.log_dir else args.env_name
        logger.configure(
            dir=os.path.join('logs', 'switch_her', log_subdir),
            format_strs=['tensorboard', 'log', 'csv', 'json', 'stdout'])
    args.log_dir = logger.get_dir()
    assert args.log_dir is not None
    os.makedirs(args.log_dir, exist_ok=True)
    # TODO: Write code for loading and saving params from/to json files
    # Get the environment parameters
    env_params = get_env_params(env)
    # Create the HER agent to interact with the environment
    her_trainer = her_switch_agent(args, env, env_params, controller)
    her_trainer.learn()
Example #3
def __init__(self, args, env_params):
    # Controller that plans with the internal model
    self.controller = get_controller(args.env_name,
                                     env_id=args.env_id,
                                     discrete=True,
                                     num_expansions=args.offline_num_expansions,
                                     reward_type=args.reward_type)
    # Learned residual on top of the controller
    self.residual = Residual(env_params)
    self.env = make_env(args.env_name, args.env_id,
                        discrete=True, reward_type=args.reward_type)
    # Normalizer for the feature vector
    self.f_norm = normalizer(env_params['num_features'])
    # Cache a simulator state from an initial reset
    self.dummy_sim_state = self.env.reset()['sim_state']
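
Note that this snippet assumes env.reset() returns a dict containing a 'sim_state' entry; the wrapped environments apparently expose the raw simulator state alongside the observation.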
Example #4
def __init__(self, args, env_params, worker_id=0):
    # Save args
    self.args, self.env_params = args, env_params
    # Env
    self.env_id = args.planning_env_id
    self.env = make_env(env_name=args.env_name,
                        env_id=self.env_id,
                        discrete=True,
                        reward_type=args.reward_type)
    # Set environment seed, offset by worker id so workers differ
    self.env.seed(args.seed + worker_id)
    # Make deterministic, if required
    if args.deterministic:
        self.env.make_deterministic()
    # Controller
    self.controller = get_controller(
        env_name=args.env_name,
        num_expansions=args.n_expansions,
        # NOTE: Controller can only use internal model
        env_id=args.planning_env_id,
        discrete=True,
        reward_type=args.reward_type,
        seed=args.seed + worker_id)
    # State value residual
    self.state_value_residual = StateValueResidual(env_params)
    # KD-trees, one per discrete action
    self.kdtrees = [None for _ in range(self.env_params['num_actions'])]
    # Normalizers
    self.features_normalizer = FeatureNormalizer(env_params)
    # Dynamics model: a single network residual for mbpo, otherwise one
    # per-action residual model (KNN or GP)
    if self.args.agent == 'mbpo':
        self.residual_dynamics = DynamicsResidual(env_params)
    elif self.args.agent == 'mbpo_knn':
        self.residual_dynamics = [
            KNNDynamicsResidual(args, env_params)
            for _ in range(self.env_params['num_actions'])
        ]
    else:
        self.residual_dynamics = [
            GPDynamicsResidual(args, env_params)
            for _ in range(self.env_params['num_actions'])
        ]
    # Flags
    self.kdtrees_set = False
    self.residual_dynamics_set = False
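
For context, a constructor like this is typically called once per parallel worker. A hypothetical usage sketch (the PlanningWorker class name and args.n_workers field are assumptions, since the listing only shows __init__):

# Hypothetical: one worker per process, each seeded differently via
# args.seed + worker_id
workers = [PlanningWorker(args, env_params, worker_id=i)
           for i in range(args.n_workers)]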
Example #5
import logging
import os
import os.path as osp
import time

import ray

# make_env, get_controller, logger, set_global_seed, get_env_params and
# the fetch_* agent constructors are helpers from the surrounding project.


def launch(args):
    # rospy.init_node('rts_trainer', anonymous=True)
    # Start ray, keeping its logging quiet
    ray.init(logging_level=logging.ERROR)
    # Create environments
    env = make_env(env_name=args.env_name,
                   env_id=args.env_id,
                   discrete=True,
                   reward_type=args.reward_type)
    planning_env = make_env(env_name=args.env_name,
                            env_id=args.planning_env_id,
                            discrete=True,
                            reward_type=args.reward_type)
    # Set random seeds
    env.seed(args.seed)
    planning_env.seed(args.seed)
    # Set global seeds
    set_global_seed(args.seed)
    # Make deterministic, if required
    if args.deterministic:
        env.make_deterministic()
        planning_env.make_deterministic()
    # Set logger level to debug, if you have to
    if args.debug:
        logger.set_level(logger.DEBUG)
    # Create controller
    controller = get_controller(env_name=args.env_name,
                                env_id=args.planning_env_id,
                                discrete=True,
                                num_expansions=args.n_expansions,
                                reward_type=args.reward_type,
                                seed=args.seed)
    # Configure logger
    if args.log_dir:
        logger.configure(
            dir=osp.join('logs', 'rts', args.log_dir),
            format_strs=['tensorboard', 'log', 'csv', 'json', 'stdout'])
    os.makedirs(logger.get_dir(), exist_ok=True)

    # Configure save dir
    # if args.save_dir:
    #     args.save_dir = osp.join('saved', 'rts', args.save_dir)
    #     os.makedirs(args.save_dir, exist_ok=True)

    # if args.load_dir:
    #     args.load_dir = osp.join('saved', 'rts', args.load_dir)
    #     # TODO: CHeck if dir exists

    # Get env params
    env_params = get_env_params(args, env)
    # Get agent
    if args.agent in ('rts', 'mbpo', 'mbpo_knn', 'mbpo_gp'):
        fetch_trainer = fetch_rts_agent(args, env_params, env, planning_env,
                                        controller)
    elif args.agent == 'dqn':
        fetch_trainer = fetch_dqn_agent(args, env_params, env, controller)
    # elif args.agent == 'mbpo':
    #     fetch_trainer = fetch_model_agent(args,
    #                                       env_params,
    #                                       env,
    #                                       planning_env,
    #                                       controller)
    # Start
    if args.offline:
        # Offline (in-model) training is disabled; the call below is
        # left unreachable on purpose
        raise Exception('Only online mode is required')
        fetch_trainer.learn_offline_in_model()
    else:
        n_steps = fetch_trainer.learn_online_in_real_world(args.max_timesteps)
        print('REACHED GOAL in', n_steps, 'by agent', args.agent)
        ray.shutdown()
        time.sleep(5)
        return n_steps
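
A launcher like this is usually driven from a small script entry point. A minimal sketch (get_args is an assumed project helper that builds the argparse namespace; it is not shown in these examples):

if __name__ == '__main__':
    args = get_args()  # hypothetical helper returning parsed CLI args
    n_steps = launch(args)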
Example #6
import argparse

# NOTE: Only the tail of the argument parser survives in this listing;
# the definitions of --env-name, --seed, --num-expansions and
# --total-num-episodes are elided above this point.
parser = argparse.ArgumentParser()
parser.add_argument('--discrete', action='store_true')
parser.add_argument('--env-id',
                    type=int,
                    default=None,
                    help='Env id for the FetchPushAmongObstacles env')
parser.add_argument('--deterministic', action='store_true')
args = parser.parse_args()

env = make_env(args.env_name, env_id=args.env_id, discrete=args.discrete)
env.seed(args.seed)

if args.deterministic:
    env.make_deterministic()

controller = get_controller(args.env_name,
                            num_expansions=args.num_expansions,
                            env_id=args.env_id,
                            discrete=args.discrete)

obs = env.reset()
t = 0
num_successes = 0.
num_episodes = 0.

f_vals_best = []
f_vals_current_episode_best = []
f_vals_start = []
f_vals_current_episode_start = []

while num_episodes < args.total_num_episodes:
    ac, info = controller.act(obs)
    print(ac, info)
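
The listing cuts off inside the evaluation loop. A hedged sketch of how the loop body typically continues (the step call, the is_success field, and the bookkeeping below are assumptions, not taken from the source):

    # Hypothetical continuation: step the environment with the
    # controller's action and track per-episode statistics
    obs, reward, done, step_info = env.step(ac)
    t += 1
    if done:
        num_successes += float(step_info.get('is_success', 0.0))
        num_episodes += 1
        obs = env.reset()
        t = 0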