Code Example #1
def get_task_dim(args):
    env = make_vec_envs(env_name=args.env_name, seed=args.seed, num_processes=args.num_processes,
                        gamma=args.policy_gamma, device=device,
                        episodes_per_task=args.max_rollouts_per_task,
                        normalise_rew=args.norm_rew_for_policy, ret_rms=None,
                        tasks=None
                        )
    return env.task_dim
Code Example #2
def get_num_tasks(args):
    env = make_vec_envs(env_name=args.env_name, seed=args.seed, num_processes=args.num_processes,
                        gamma=args.policy_gamma, device=device,
                        episodes_per_task=args.max_rollouts_per_task,
                        normalise_rew=args.norm_rew_for_policy, ret_rms=None,
                        tasks=None
                        )
    try:
        num_tasks = env.num_tasks
    except AttributeError:
        num_tasks = None
    return num_tasks
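Both helpers above spin up a vector environment purely to read its metadata; they assume `make_vec_envs` and a module-level `device` are already imported in the surrounding file. A minimal, hypothetical call-site sketch (the `args` namespace comes from the project's argument parser; the variable names are illustrative only):

task_dim = get_task_dim(args)    # e.g. dimensionality of the task/goal vector
num_tasks = get_num_tasks(args)  # None if the env does not expose `num_tasks`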
Code Example #3
File: metalearner.py    Project: orilinial/RSPP
    def __init__(self, args):

        self.args = args
        utl.seed(self.args.seed, self.args.deterministic_execution)

        # calculate number of updates and keep count of frames/iterations
        self.num_updates = int(
            args.num_frames) // args.policy_num_steps // args.num_processes
        self.frames = 0
        self.iter_idx = 0

        # initialise tensorboard logger
        self.logger = TBLogger(self.args, self.args.exp_label)

        # initialise environments
        self.envs = make_vec_envs(
            env_name=args.env_name,
            seed=args.seed,
            num_processes=args.num_processes,
            gamma=args.policy_gamma,
            device=device,
            episodes_per_task=self.args.max_rollouts_per_task,
            normalise_rew=args.norm_rew_for_policy,
            ret_rms=None,
        )

        # calculate what the maximum length of the trajectories is
        self.args.max_trajectory_len = self.envs._max_episode_steps
        self.args.max_trajectory_len *= self.args.max_rollouts_per_task

        # get policy input dimensions
        self.args.state_dim = self.envs.observation_space.shape[0]
        self.args.task_dim = self.envs.task_dim
        self.args.belief_dim = self.envs.belief_dim
        self.args.num_states = self.envs.num_states
        # get policy output (action) dimensions
        self.args.action_space = self.envs.action_space
        if isinstance(self.envs.action_space, gym.spaces.discrete.Discrete):
            self.args.action_dim = 1
        elif isinstance(self.envs.action_space,
                        gym.spaces.multi_discrete.MultiDiscrete):
            self.args.action_dim = self.envs.action_space.nvec[0]
        else:
            self.args.action_dim = self.envs.action_space.shape[0]

        # initialise VAE and policy
        self.vae = VaribadVAE(self.args, self.logger, lambda: self.iter_idx)
        self.policy_storage = self.initialise_policy_storage()
        self.policy = self.initialise_policy()
Code Example #4
File: metalearner.py    Project: jeonggwanlee/varibad
    def __init__(self, args):
        self.args = args
        utl.seed(self.args.seed, self.args.deterministic_execution)

        # count number of frames and number of meta-iterations
        self.frames = 0
        self.iter_idx = 0

        # initialise tensorboard logger
        self.logger = TBLogger(self.args, self.args.exp_label)

        # initialise environments
        self.envs = make_vec_envs(
            env_name=args.env_name,
            seed=args.seed,
            num_processes=args.num_processes,
            gamma=args.policy_gamma,
            log_dir=args.agent_log_dir,
            device=device,
            allow_early_resets=False,
            episodes_per_task=self.args.max_rollouts_per_task,
            obs_rms=None,
            ret_rms=None,
        )

        # calculate what the maximum length of the trajectories is
        args.max_trajectory_len = self.envs._max_episode_steps
        args.max_trajectory_len *= self.args.max_rollouts_per_task

        # calculate number of meta updates
        self.args.num_updates = int(
            args.num_frames) // args.policy_num_steps // args.num_processes

        # get action / observation dimensions
        if isinstance(self.envs.action_space, gym.spaces.discrete.Discrete):
            self.args.action_dim = 1
        else:
            self.args.action_dim = self.envs.action_space.shape[0]
        self.args.obs_dim = self.envs.observation_space.shape[0]
        self.args.num_states = (self.envs.num_states
                                if self.args.env_name.startswith('Grid') else None)
        self.args.act_space = self.envs.action_space

        self.vae = VaribadVAE(self.args, self.logger, lambda: self.iter_idx)

        self.initialise_policy()
Code Example #5
File: main.py    Project: lmzintgraf/varibad
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--env-type', default='gridworld_varibad')
    args, rest_args = parser.parse_known_args()
    env = args.env_type

    # --- GridWorld ---

    if env == 'gridworld_belief_oracle':
        args = args_grid_belief_oracle.get_args(rest_args)
    elif env == 'gridworld_varibad':
        args = args_grid_varibad.get_args(rest_args)
    elif env == 'gridworld_rl2':
        args = args_grid_rl2.get_args(rest_args)

    # --- PointRobot 2D Navigation ---

    elif env == 'pointrobot_multitask':
        args = args_pointrobot_multitask.get_args(rest_args)
    elif env == 'pointrobot_varibad':
        args = args_pointrobot_varibad.get_args(rest_args)
    elif env == 'pointrobot_rl2':
        args = args_pointrobot_rl2.get_args(rest_args)
    elif env == 'pointrobot_humplik':
        args = args_pointrobot_humplik.get_args(rest_args)

    # --- MUJOCO ---

    # - CheetahDir -
    elif env == 'cheetah_dir_multitask':
        args = args_cheetah_dir_multitask.get_args(rest_args)
    elif env == 'cheetah_dir_expert':
        args = args_cheetah_dir_expert.get_args(rest_args)
    elif env == 'cheetah_dir_varibad':
        args = args_cheetah_dir_varibad.get_args(rest_args)
    elif env == 'cheetah_dir_rl2':
        args = args_cheetah_dir_rl2.get_args(rest_args)
    #
    # - CheetahVel -
    elif env == 'cheetah_vel_multitask':
        args = args_cheetah_vel_multitask.get_args(rest_args)
    elif env == 'cheetah_vel_expert':
        args = args_cheetah_vel_expert.get_args(rest_args)
    elif env == 'cheetah_vel_avg':
        args = args_cheetah_vel_avg.get_args(rest_args)
    elif env == 'cheetah_vel_varibad':
        args = args_cheetah_vel_varibad.get_args(rest_args)
    elif env == 'cheetah_vel_rl2':
        args = args_cheetah_vel_rl2.get_args(rest_args)
    #
    # - AntDir -
    elif env == 'ant_dir_multitask':
        args = args_ant_dir_multitask.get_args(rest_args)
    elif env == 'ant_dir_expert':
        args = args_ant_dir_expert.get_args(rest_args)
    elif env == 'ant_dir_varibad':
        args = args_ant_dir_varibad.get_args(rest_args)
    elif env == 'ant_dir_rl2':
        args = args_ant_dir_rl2.get_args(rest_args)
    #
    # - AntGoal -
    elif env == 'ant_goal_multitask':
        args = args_ant_goal_multitask.get_args(rest_args)
    elif env == 'ant_goal_expert':
        args = args_ant_goal_expert.get_args(rest_args)
    elif env == 'ant_goal_varibad':
        args = args_ant_goal_varibad.get_args(rest_args)
    elif env == 'ant_goal_humplik':
        args = args_ant_goal_humplik.get_args(rest_args)
    elif env == 'ant_goal_rl2':
        args = args_ant_goal_rl2.get_args(rest_args)
    #
    # - Walker -
    elif env == 'walker_multitask':
        args = args_walker_multitask.get_args(rest_args)
    elif env == 'walker_expert':
        args = args_walker_expert.get_args(rest_args)
    elif env == 'walker_avg':
        args = args_walker_avg.get_args(rest_args)
    elif env == 'walker_varibad':
        args = args_walker_varibad.get_args(rest_args)
    elif env == 'walker_rl2':
        args = args_walker_rl2.get_args(rest_args)
    #
    # - HumanoidDir -
    elif env == 'humanoid_dir_multitask':
        args = args_humanoid_dir_multitask.get_args(rest_args)
    elif env == 'humanoid_dir_expert':
        args = args_humanoid_dir_expert.get_args(rest_args)
    elif env == 'humanoid_dir_varibad':
        args = args_humanoid_dir_varibad.get_args(rest_args)
    elif env == 'humanoid_dir_rl2':
        args = args_humanoid_dir_rl2.get_args(rest_args)
    else:
        raise Exception("Invalid Environment")

    # warning for deterministic execution
    if args.deterministic_execution:
        print('Invoking deterministic code execution.')
        if torch.backends.cudnn.enabled:
            warnings.warn('Running with deterministic CUDNN.')
        if args.num_processes > 1:
            raise RuntimeError(
                'If you want fully deterministic code, run it with num_processes=1. '
                'Warning: This will slow things down and might break A2C if '
                'policy_num_steps < env._max_episode_steps.')

    # if we're normalising the actions, we have to make sure that the env expects actions within [-1, 1]
    if args.norm_actions_pre_sampling or args.norm_actions_post_sampling:
        envs = make_vec_envs(
            env_name=args.env_name,
            seed=0,
            num_processes=args.num_processes,
            gamma=args.policy_gamma,
            device='cpu',
            episodes_per_task=args.max_rollouts_per_task,
            normalise_rew=args.norm_rew_for_policy,
            ret_rms=None,
            tasks=None,
        )
        assert np.unique(envs.action_space.low) == [-1]
        assert np.unique(envs.action_space.high) == [1]

    # clean up arguments
    if args.disable_metalearner or args.disable_decoder:
        args.decode_reward = False
        args.decode_state = False
        args.decode_task = False

    if hasattr(args, 'decode_only_past') and args.decode_only_past:
        args.split_batches_by_elbo = True
    # if hasattr(args, 'vae_subsample_decodes') and args.vae_subsample_decodes:
    #     args.split_batches_by_elbo = True

    # begin training (loop through all passed seeds)
    seed_list = [args.seed] if isinstance(args.seed, int) else args.seed
    for seed in seed_list:
        print('training', seed)
        args.seed = seed
        args.action_space = None

        if args.disable_metalearner:
            # If `disable_metalearner` is true, the file `learner.py` will be used instead of `metalearner.py`.
            # This is a stripped down version without encoder, decoder, stochastic latent variables, etc.
            learner = Learner(args)
        else:
            learner = MetaLearner(args)
        learner.train()
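The excerpt ends at the training call; the entry-point guard is not shown. A hypothetical invocation sketch, assuming the standard pattern:

if __name__ == '__main__':
    main()

# run e.g. as:  python main.py --env-type gridworld_varibad
# (all remaining flags are forwarded to the environment-specific parser via parse_known_args)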
Code Example #6
File: metalearner.py    Project: jeonggwanlee/varibad
    def load_and_render(self, load_iter):
        #save_path = os.path.join('/ext/varibad_github/v2/varibad/logs/logs_HalfCheetahJoint-v0/varibad_73__15:05_17:14:07', 'models')
        #save_path = os.path.join('/ext/varibad_github/v2/varibad/logs/hfield', 'models')
        save_path = os.path.join(
            '/ext/varibad_github/v2/varibad/logs/logs_HalfCheetahBlocks-v0/varibad_73__15:05_20:20:25',
            'models')
        self.policy.actor_critic = torch.load(
            os.path.join(save_path, "policy{0}.pt".format(load_iter)))
        self.vae.encoder = torch.load(
            os.path.join(save_path, "encoder{0}.pt".format(load_iter)))

        args = self.args
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        num_processes = 1
        num_episodes = 100
        num_steps = 1999

        #import pdb; pdb.set_trace()
        # initialise environments
        envs = make_vec_envs(
            env_name=args.env_name,
            seed=args.seed,
            num_processes=num_processes,  # 1
            gamma=args.policy_gamma,
            log_dir=args.agent_log_dir,
            device=device,
            allow_early_resets=False,
            episodes_per_task=self.args.max_rollouts_per_task,
            obs_rms=None,
            ret_rms=None,
        )

        # reset latent state to prior
        latent_sample, latent_mean, latent_logvar, hidden_state = self.vae.encoder.prior(
            num_processes)

        for episode_idx in range(num_episodes):
            (prev_obs_raw, prev_obs_normalised) = envs.reset()
            prev_obs_raw = prev_obs_raw.to(device)
            prev_obs_normalised = prev_obs_normalised.to(device)
            for step_idx in range(num_steps):

                with torch.no_grad():
                    _, action, _ = utl.select_action(
                        args=self.args,
                        policy=self.policy,
                        obs=prev_obs_normalised
                        if self.args.norm_obs_for_policy else prev_obs_raw,
                        latent_sample=latent_sample,
                        latent_mean=latent_mean,
                        latent_logvar=latent_logvar,
                        deterministic=True)

                # observe reward and next obs
                (next_obs_raw, next_obs_normalised), (rew_raw, rew_normalised), done, infos = \
                    utl.env_step(envs, action)
                # render
                envs.venv.venv.envs[0].env.env.env.env.render()

                # update the hidden state
                latent_sample, latent_mean, latent_logvar, hidden_state = utl.update_encoding(
                    encoder=self.vae.encoder,
                    next_obs=next_obs_raw,
                    action=action,
                    reward=rew_raw,
                    done=None,
                    hidden_state=hidden_state)

                prev_obs_normalised = next_obs_normalised
                prev_obs_raw = next_obs_raw

                if done[0]:
                    break
Code Example #7
File: evaluation.py    Project: lmzintgraf/varibad
def visualise_behaviour(
    args,
    policy,
    image_folder,
    iter_idx,
    ret_rms,
    tasks,
    encoder=None,
    reward_decoder=None,
    state_decoder=None,
    task_decoder=None,
    compute_rew_reconstruction_loss=None,
    compute_task_reconstruction_loss=None,
    compute_state_reconstruction_loss=None,
    compute_kl_loss=None,
):
    # initialise environment
    env = make_vec_envs(
        env_name=args.env_name,
        seed=args.seed * 42 + iter_idx,
        num_processes=1,
        gamma=args.policy_gamma,
        device=device,
        episodes_per_task=args.max_rollouts_per_task,
        normalise_rew=args.norm_rew_for_policy,
        ret_rms=ret_rms,
        rank_offset=args.num_processes + 42,  # not sure if the temp folders would otherwise clash
        tasks=tasks)
    episode_task = torch.from_numpy(np.array(
        env.get_task())).to(device).float()

    # get a sample rollout
    unwrapped_env = env.venv.unwrapped.envs[0]
    if hasattr(env.venv.unwrapped.envs[0], 'unwrapped'):
        unwrapped_env = unwrapped_env.unwrapped
    if hasattr(unwrapped_env, 'visualise_behaviour'):
        # if possible, get it from the env directly
        # (this might visualise other things in addition)
        traj = unwrapped_env.visualise_behaviour(
            env=env,
            args=args,
            policy=policy,
            iter_idx=iter_idx,
            encoder=encoder,
            reward_decoder=reward_decoder,
            state_decoder=state_decoder,
            task_decoder=task_decoder,
            image_folder=image_folder,
        )
    else:
        traj = get_test_rollout(args, env, policy, encoder)

    latent_means, latent_logvars, episode_prev_obs, episode_next_obs, episode_actions, episode_rewards, episode_returns = traj

    if latent_means is not None:
        plot_latents(latent_means,
                     latent_logvars,
                     image_folder=image_folder,
                     iter_idx=iter_idx)

        if not (args.disable_decoder and args.disable_kl_term):
            plot_vae_loss(
                args,
                latent_means,
                latent_logvars,
                episode_prev_obs,
                episode_next_obs,
                episode_actions,
                episode_rewards,
                episode_task,
                image_folder=image_folder,
                iter_idx=iter_idx,
                reward_decoder=reward_decoder,
                state_decoder=state_decoder,
                task_decoder=task_decoder,
                compute_task_reconstruction_loss=compute_task_reconstruction_loss,
                compute_rew_reconstruction_loss=compute_rew_reconstruction_loss,
                compute_state_reconstruction_loss=compute_state_reconstruction_loss,
                compute_kl_loss=compute_kl_loss,
            )

    env.close()
Code Example #8
File: evaluation.py    Project: lmzintgraf/varibad
def evaluate(args,
             policy,
             ret_rms,
             iter_idx,
             tasks,
             encoder=None,
             num_episodes=None):
    env_name = args.env_name
    if hasattr(args, 'test_env_name'):
        env_name = args.test_env_name
    if num_episodes is None:
        num_episodes = args.max_rollouts_per_task
    num_processes = args.num_processes

    # --- set up the things we want to log ---

    # for each process, we log the returns during the first, second, ... episode
    # (such that we have a minimum of [num_episodes]; the last column is for
    #  any overflow and will be discarded at the end, because we need to wait until
    #  all processes have at least [num_episodes] many episodes)
    returns_per_episode = torch.zeros(
        (num_processes, num_episodes + 1)).to(device)

    # --- initialise environments and latents ---

    envs = make_vec_envs(
        env_name,
        seed=args.seed * 42 + iter_idx,
        num_processes=num_processes,
        gamma=args.policy_gamma,
        device=device,
        rank_offset=num_processes + 1,  # to use diff tmp folders than main processes
        episodes_per_task=num_episodes,
        normalise_rew=args.norm_rew_for_policy,
        ret_rms=ret_rms,
        tasks=tasks,
        add_done_info=args.max_rollouts_per_task > 1,
    )
    num_steps = envs._max_episode_steps

    # reset environments
    state, belief, task = utl.reset_env(envs, args)

    # this counts how often an agent has done the same task already
    task_count = torch.zeros(num_processes).long().to(device)

    if encoder is not None:
        # reset latent state to prior
        latent_sample, latent_mean, latent_logvar, hidden_state = encoder.prior(
            num_processes)
    else:
        latent_sample = latent_mean = latent_logvar = hidden_state = None

    for episode_idx in range(num_episodes):

        for step_idx in range(num_steps):

            with torch.no_grad():
                _, action = utl.select_action(args=args,
                                              policy=policy,
                                              state=state,
                                              belief=belief,
                                              task=task,
                                              latent_sample=latent_sample,
                                              latent_mean=latent_mean,
                                              latent_logvar=latent_logvar,
                                              deterministic=True)

            # observe reward and next obs
            [state, belief, task], (rew_raw, rew_normalised), done, infos = \
                utl.env_step(envs, action, args)
            done_mdp = [info['done_mdp'] for info in infos]

            if encoder is not None:
                # update the hidden state
                latent_sample, latent_mean, latent_logvar, hidden_state = utl.update_encoding(
                    encoder=encoder,
                    next_obs=state,
                    action=action,
                    reward=rew_raw,
                    done=None,
                    hidden_state=hidden_state)

            # add rewards
            returns_per_episode[range(num_processes),
                                task_count] += rew_raw.view(-1)

            for i in np.argwhere(done_mdp).flatten():
                # count task up, but cap at num_episodes + 1
                task_count[i] = min(task_count[i] + 1,
                                    num_episodes)  # zero-indexed, so no +1
            if np.sum(done) > 0:
                done_indices = np.argwhere(done.flatten()).flatten()
                state, belief, task = utl.reset_env(envs,
                                                    args,
                                                    indices=done_indices,
                                                    state=state)

    envs.close()

    return returns_per_episode[:, :num_episodes]
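The returned tensor has shape (num_processes, num_episodes). A hypothetical post-processing sketch for logging, averaging over the parallel processes to get the mean return of the first, second, ... episode on a task (the call arguments shown are illustrative):

returns = evaluate(args, policy, ret_rms=None, iter_idx=iter_idx, tasks=None, encoder=encoder)
mean_return_per_episode = returns.mean(dim=0)  # shape: (num_episodes,)
for ep, ret in enumerate(mean_return_per_episode.tolist()):
    print('episode {}: mean return {:.2f}'.format(ep, ret))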
Code Example #9
File: learner.py    Project: lmzintgraf/varibad
    def __init__(self, args):

        self.args = args
        utl.seed(self.args.seed, self.args.deterministic_execution)

        # calculate number of updates and keep count of frames/iterations
        self.num_updates = int(
            args.num_frames) // args.policy_num_steps // args.num_processes
        self.frames = 0
        self.iter_idx = -1

        # initialise tensorboard logger
        self.logger = TBLogger(self.args, self.args.exp_label)

        # initialise environments
        self.envs = make_vec_envs(
            env_name=args.env_name,
            seed=args.seed,
            num_processes=args.num_processes,
            gamma=args.policy_gamma,
            device=device,
            episodes_per_task=self.args.max_rollouts_per_task,
            normalise_rew=args.norm_rew_for_policy,
            ret_rms=None,
            tasks=None)

        if self.args.single_task_mode:
            # get the current tasks (which will be num_process many different tasks)
            self.train_tasks = self.envs.get_task()
            # set the tasks to the first task (i.e. just a random task)
            self.train_tasks[1:] = self.train_tasks[0]
            # make it a list
            self.train_tasks = [t for t in self.train_tasks]
            # re-initialise environments with those tasks
            self.envs = make_vec_envs(
                env_name=args.env_name,
                seed=args.seed,
                num_processes=args.num_processes,
                gamma=args.policy_gamma,
                device=device,
                episodes_per_task=self.args.max_rollouts_per_task,
                normalise_rew=args.norm_rew_for_policy,
                ret_rms=None,
                tasks=self.train_tasks,
            )
            # save the training tasks so we can evaluate on the same envs later
            utl.save_obj(self.train_tasks, self.logger.full_output_folder,
                         "train_tasks")
        else:
            self.train_tasks = None

        # calculate what the maximum length of the trajectories is
        args.max_trajectory_len = self.envs._max_episode_steps
        args.max_trajectory_len *= self.args.max_rollouts_per_task

        # get policy input dimensions
        self.args.state_dim = self.envs.observation_space.shape[0]
        self.args.task_dim = self.envs.task_dim
        self.args.belief_dim = self.envs.belief_dim
        self.args.num_states = self.envs.num_states
        # get policy output (action) dimensions
        self.args.action_space = self.envs.action_space
        if isinstance(self.envs.action_space, gym.spaces.discrete.Discrete):
            self.args.action_dim = 1
        else:
            self.args.action_dim = self.envs.action_space.shape[0]

        # initialise policy
        self.policy_storage = self.initialise_policy_storage()
        self.policy = self.initialise_policy()