Example #1
    def initialise_policy(self):

        # initialise policy network
        policy_net = Policy(
            args=self.args,
            #
            pass_state_to_policy=self.args.pass_state_to_policy,
            pass_latent_to_policy=self.args.pass_latent_to_policy,
            pass_belief_to_policy=self.args.pass_belief_to_policy,
            pass_task_to_policy=self.args.pass_task_to_policy,
            dim_state=self.args.state_dim,
            dim_latent=self.args.latent_dim * 2,
            dim_belief=self.args.belief_dim,
            dim_task=self.args.task_dim,
            #
            hidden_layers=self.args.policy_layers,
            activation_function=self.args.policy_activation_function,
            policy_initialisation=self.args.policy_initialisation,
            #
            action_space=self.envs.action_space,
            init_std=self.args.policy_init_std,
        ).to(device)

        # initialise policy trainer
        if self.args.policy == 'a2c':
            policy = A2C(
                self.args,
                policy_net,
                self.args.policy_value_loss_coef,
                self.args.policy_entropy_coef,
                policy_optimiser=self.args.policy_optimiser,
                policy_anneal_lr=self.args.policy_anneal_lr,
                train_steps=self.num_updates,
                optimiser_vae=self.vae.optimiser_vae,
                lr=self.args.lr_policy,
                eps=self.args.policy_eps,
            )
        elif self.args.policy == 'ppo':
            policy = PPO(
                self.args,
                policy_net,
                self.args.policy_value_loss_coef,
                self.args.policy_entropy_coef,
                policy_optimiser=self.args.policy_optimiser,
                policy_anneal_lr=self.args.policy_anneal_lr,
                train_steps=self.num_updates,
                lr=self.args.lr_policy,
                eps=self.args.policy_eps,
                ppo_epoch=self.args.ppo_num_epochs,
                num_mini_batch=self.args.ppo_num_minibatch,
                use_huber_loss=self.args.ppo_use_huberloss,
                use_clipped_value_loss=self.args.ppo_use_clipped_value_loss,
                clip_param=self.args.ppo_clip_param,
                optimiser_vae=self.vae.optimiser_vae,
            )
        else:
            raise NotImplementedError

        return policy
Example #2
    def initialise_policy(self):

        if hasattr(self.envs.action_space, 'low'):
            action_low = self.envs.action_space.low
            action_high = self.envs.action_space.high
        else:
            action_low = action_high = None

        # initialise policy network
        policy_net = Policy(
            args=self.args,
            #
            pass_state_to_policy=self.args.pass_state_to_policy,
            pass_latent_to_policy=False,  # use metalearner.py if you want to use the VAE
            pass_belief_to_policy=self.args.pass_belief_to_policy,
            pass_task_to_policy=self.args.pass_task_to_policy,
            dim_state=self.args.state_dim,
            dim_latent=0,
            dim_belief=self.args.belief_dim,
            dim_task=self.args.task_dim,
            #
            hidden_layers=self.args.policy_layers,
            activation_function=self.args.policy_activation_function,
            policy_initialisation=self.args.policy_initialisation,
            #
            action_space=self.envs.action_space,
            init_std=self.args.policy_init_std,
            norm_actions_of_policy=self.args.norm_actions_of_policy,
            action_low=action_low,
            action_high=action_high,
        ).to(device)

        # initialise policy trainer
        if self.args.policy == 'a2c':
            policy = A2C(
                self.args,
                policy_net,
                self.args.policy_value_loss_coef,
                self.args.policy_entropy_coef,
                policy_optimiser=self.args.policy_optimiser,
                policy_anneal_lr=self.args.policy_anneal_lr,
                train_steps=self.num_updates,
                lr=self.args.lr_policy,
                eps=self.args.policy_eps,
            )
        elif self.args.policy == 'ppo':
            policy = PPO(
                self.args,
                policy_net,
                self.args.policy_value_loss_coef,
                self.args.policy_entropy_coef,
                policy_optimiser=self.args.policy_optimiser,
                policy_anneal_lr=self.args.policy_anneal_lr,
                train_steps=self.num_updates,
                lr=self.args.lr_policy,
                eps=self.args.policy_eps,
                ppo_epoch=self.args.ppo_num_epochs,
                num_mini_batch=self.args.ppo_num_minibatch,
                use_huber_loss=self.args.ppo_use_huberloss,
                use_clipped_value_loss=self.args.ppo_use_clipped_value_loss,
                clip_param=self.args.ppo_clip_param,
            )
        else:
            raise NotImplementedError

        return policy
Example #3
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Apr 18 16:35:31 2019

@author: clytie
"""

if __name__ == "__main__":
    import numpy as np
    import time
    from tqdm import tqdm
    from env.dist_env import BreakoutEnv
    from algorithms.ppo import PPO

    ppo = PPO(4, (84, 84, 4), temperature=0.1, save_path="./ppo_log")
    env = BreakoutEnv(4999, num_envs=1, mode="test")
    env_ids, states, _, _ = env.start()
    for _ in tqdm(range(10000)):
        time.sleep(0.1)
        actions = ppo.get_action(np.asarray(states))
        env_ids, states, _, _ = env.step(env_ids, actions)
    env.close()
Example #4
        x = tf.layers.conv2d(x, 64, 3, 1, activation=tf.nn.relu)
        x = tf.contrib.layers.flatten(x)
        x = tf.layers.dense(x, 512, activation=tf.nn.relu)

        logit_action_probability = tf.layers.dense(
            x,
            action_space,
            kernel_initializer=tf.truncated_normal_initializer(0.0, 0.01))
        state_value = tf.squeeze(
            tf.layers.dense(
                x, 1, kernel_initializer=tf.truncated_normal_initializer()))
        return logit_action_probability, state_value

    ppo = PPO(action_space,
              obs_fn,
              model_fn,
              train_epoch=5,
              batch_size=64,
              save_path='./ppo_log_oneplayer_stack')

    env = Raiden2(6666, num_envs=8, with_stack=True)
    env_ids, states, rewards, dones = env.start()

    nth_trajectory = 0
    while True:
        nth_trajectory += 1
        for _ in tqdm(range(explore_steps)):
            actions = ppo.get_action(np.asarray(states))
            actions = [(action, 4) for action in actions]
            env_ids, states, rewards, dones = env.step(env_ids, actions)

        s_batchs, a_batchs, r_batchs, d_batchs = env.get_episodes()
Example #5
def run(config):
    model_path = (Path('./models') / config.env_id /
                  ('run%i' % config.run_num))
    if config.incremental is not None:
        model_path = model_path / 'incremental' / ('model_ep%i.pt' %
                                                   config.incremental)
    else:
        model_path = model_path / 'model.pt'

    if config.save_gifs:
        gif_path = model_path.parent / 'gifs'
        gif_path.mkdir(exist_ok=True)

    ppo = PPO.init_from_save(model_path)
    env = make_env(config.env_id)
    ppo.prep_rollouts(device='cpu')
    ifi = 1 / config.fps  # inter-frame interval
    for ep_i in range(config.n_episodes):
        print("Episode %i of %i" % (ep_i + 1, config.n_episodes))
        obs = env.reset()
        nagents = len(obs)

        dones = [False] * nagents
        for agent in env.agents:
            agent.trajectory = []

        if config.save_gifs:
            frames = []
            frames.append(env.render('rgb_array')[0])
        act_hidden = [[torch.zeros(1, 128), torch.zeros(1, 128)] for i in range(nagents)]
        crt_hidden = [[torch.zeros(1, 128), torch.zeros(1, 128)] for i in range(nagents)]
        rews = None

        env.render('human')

        for t_i in range(config.episode_length):
            print(f'{t_i} / {config.episode_length}')
            calc_start = time.time()
            nagents = len(obs)
            torch_obs = [Variable(torch.Tensor(obs[i]).view(1, -1),
                                  requires_grad=False)
                         for i in range(nagents)]
            _, _, _, mean_list = ppo.step(torch_obs, act_hidden, crt_hidden)
            agent_actions_list = [a.data.cpu().numpy() for a in mean_list]
            clipped_action_list = [np.clip(a, -1, 1) for a in agent_actions_list]
            actions = [ac.flatten() for ac in clipped_action_list]
            for i in range(len(dones)):
                if dones[i]:
                    env.agents[i].movable = False
                else:
                    env.agents[i].trajectory.append(np.copy(env.agents[i].state.p_pos))
            obs, rewards, dones, infos = env.step(actions)
            if rews is None:
                rews = np.zeros(len(rewards))
            rews += np.array(rewards)
            if config.save_gifs:
                frames.append(env.render('rgb_array')[0])
            calc_end = time.time()
            elapsed = calc_end - calc_start
            if elapsed < ifi:
                time.sleep(ifi - elapsed)
            env.render('human')
        if config.save_gifs:
            gif_num = 0
            while (gif_path / ('%i_%i.gif' % (gif_num, ep_i))).exists():
                gif_num += 1
            imageio.mimsave(str(gif_path / ('%i_%i.gif' % (gif_num, ep_i))),
                            frames, duration=ifi)

    env.close()
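# `PPO.init_from_save(model_path)` above restores a trained agent from a checkpoint; a
# hypothetical sketch of that pattern (the checkpoint keys 'init_dict' and 'agent_params'
# and the `policy` attribute are assumptions, not this PPO class's actual format):
import torch

class SavablePPOExample:
    def __init__(self, **init_dict):
        self.init_dict = init_dict

    @classmethod
    def init_from_save(cls, filename):
        save_dict = torch.load(filename, map_location='cpu')
        instance = cls(**save_dict['init_dict'])  # rebuild with the saved hyper-parameters
        # network weights would be restored here, e.g.:
        # instance.policy.load_state_dict(save_dict['agent_params'])
        return instance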
Example #6
    import numpy as np
    from tqdm import tqdm
    import logging
    from algorithms.ppo import PPO
    from env.dist_env import BreakoutEnv

    logging.basicConfig(level=logging.INFO,
                        format='%(asctime)s|%(levelname)s|%(message)s')

    explore_steps = 512
    total_updates = 2000
    save_model_freq = 100

    env = BreakoutEnv(50002, num_envs=20)
    env_ids, states, rewards, dones = env.start()
    ppo = PPO(env.action_space,
              env.state_space,
              train_epoch=5,
              clip_schedule=lambda x: 0.2)

    nth_trajectory = 0
    while True:
        nth_trajectory += 1
        for _ in tqdm(range(explore_steps)):
            actions = ppo.get_action(np.asarray(states))
            env_ids, states, rewards, dones = env.step(env_ids, actions)

        s_batch, a_batch, r_batch, d_batch = env.get_episodes()
        logging.info(f'>>>>{env.mean_reward}, nth_trajectory{nth_trajectory}')

        ppo.update(s_batch, a_batch, r_batch, d_batch,
                   min(0.9, nth_trajectory / total_updates))
        ppo.sw.add_scalar('epreward_mean', env.mean_reward, nth_trajectory)  # value/step arguments assumed
Example #7
def main():
    # setup parameters
    args = SimpleNamespace(
        env_module="environments",
        env_name="TargetEnv-v0",
        device="cuda:0" if torch.cuda.is_available() else "cpu",
        num_parallel=100,
        vae_path="models/",
        frame_skip=1,
        seed=16,
        load_saved_model=False,
    )

    args.num_parallel *= args.frame_skip
    env = make_gym_environment(args)

    # env parameters
    args.action_size = env.action_space.shape[0]
    args.observation_size = env.observation_space.shape[0]

    # other configs
    args.save_path = os.path.join(current_dir, "con_" + args.env_name + ".pt")

    # sampling parameters
    args.num_frames = 10e7
    args.num_steps_per_rollout = env.unwrapped.max_timestep
    args.num_updates = int(args.num_frames / args.num_parallel /
                           args.num_steps_per_rollout)

    # learning parameters
    args.lr = 3e-5
    args.final_lr = 1e-5
    args.eps = 1e-5
    args.lr_decay_type = "exponential"
    args.mini_batch_size = 1000
    args.num_mini_batch = (args.num_parallel * args.num_steps_per_rollout //
                           args.mini_batch_size)

    # ppo parameters
    use_gae = True
    entropy_coef = 0.0
    value_loss_coef = 1.0
    ppo_epoch = 10
    gamma = 0.99
    gae_lambda = 0.95
    clip_param = 0.2
    max_grad_norm = 1.0

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    obs_shape = env.observation_space.shape
    obs_shape = (obs_shape[0], *obs_shape[1:])

    if args.load_saved_model:
        actor_critic = torch.load(args.save_path, map_location=args.device)
        print("Loading model:", args.save_path)
    else:
        controller = PoseVAEController(env)
        actor_critic = PoseVAEPolicy(controller)

    actor_critic = actor_critic.to(args.device)
    actor_critic.env_info = {"frame_skip": args.frame_skip}

    agent = PPO(
        actor_critic,
        clip_param,
        ppo_epoch,
        args.num_mini_batch,
        value_loss_coef,
        entropy_coef,
        lr=args.lr,
        eps=args.eps,
        max_grad_norm=max_grad_norm,
    )

    rollouts = RolloutStorage(
        args.num_steps_per_rollout,
        args.num_parallel,
        obs_shape,
        args.action_size,
        actor_critic.state_size,
    )
    obs = env.reset()
    rollouts.observations[0].copy_(obs)
    rollouts.to(args.device)

    log_path = os.path.join(current_dir,
                            "log_ppo_progress-{}".format(args.env_name))
    logger = StatsLogger(csv_path=log_path)

    for update in range(args.num_updates):

        ep_info = {"reward": []}
        ep_reward = 0

        if args.lr_decay_type == "linear":
            update_linear_schedule(agent.optimizer, update, args.num_updates,
                                   args.lr, args.final_lr)
        elif args.lr_decay_type == "exponential":
            update_exponential_schedule(agent.optimizer, update, 0.99, args.lr,
                                        args.final_lr)

        for step in range(args.num_steps_per_rollout):
            # Sample actions
            with torch.no_grad():
                value, action, action_log_prob = actor_critic.act(
                    rollouts.observations[step])

            obs, reward, done, info = env.step(action)
            ep_reward += reward

            end_of_rollout = info.get("reset")
            masks = (~done).float()
            bad_masks = (~(done * end_of_rollout)).float()

            if done.any():
                ep_info["reward"].append(ep_reward[done].clone())
                ep_reward *= (~done).float()  # zero out the dones
                reset_indices = env.parallel_ind_buf.masked_select(
                    done.squeeze())
                obs = env.reset(reset_indices)

            if end_of_rollout:
                obs = env.reset()

            rollouts.insert(obs, action, action_log_prob, value, reward, masks,
                            bad_masks)

        with torch.no_grad():
            next_value = actor_critic.get_value(
                rollouts.observations[-1]).detach()

        rollouts.compute_returns(next_value, use_gae, gamma, gae_lambda)

        value_loss, action_loss, dist_entropy = agent.update(rollouts)

        rollouts.after_update()

        torch.save(copy.deepcopy(actor_critic).cpu(), args.save_path)

        ep_info["reward"] = torch.cat(ep_info["reward"])
        logger.log_stats(
            args,
            {
                "update": update,
                "ep_info": ep_info,
                "dist_entropy": dist_entropy,
                "value_loss": value_loss,
                "action_loss": action_loss,
            },
        )
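# `update_linear_schedule` and `update_exponential_schedule` are referenced above but not
# shown in this listing; a minimal sketch of what such helpers could look like, with
# signatures matching the call sites (the exact annealing formulas are assumptions):
def update_linear_schedule(optimizer, update, total_updates, initial_lr, final_lr):
    # linearly anneal the learning rate from initial_lr towards final_lr
    lr = initial_lr - (initial_lr - final_lr) * (update / float(total_updates))
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr


def update_exponential_schedule(optimizer, update, decay, initial_lr, final_lr):
    # decay the learning rate by `decay` per update, never dropping below final_lr
    lr = max(initial_lr * (decay ** update), final_lr)
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr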
Example #8
def main(_seed, _config, _run):
    args = init(_seed, _config, _run)

    env_name = args.env_name

    dummy_env = make_env(env_name, render=False)

    cleanup_log_dir(args.log_dir)
    cleanup_log_dir(args.log_dir + "_test")

    try:
        os.makedirs(args.save_dir)
    except OSError:
        pass

    torch.set_num_threads(1)

    envs = make_vec_envs(env_name, args.seed, args.num_processes, args.log_dir)
    envs.set_mirror(args.use_phase_mirror)
    test_envs = make_vec_envs(env_name, args.seed, args.num_tests,
                              args.log_dir + "_test")
    test_envs.set_mirror(args.use_phase_mirror)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    if args.use_curriculum:
        curriculum = 0
        print("curriculum", curriculum)
        envs.update_curriculum(curriculum)
    if args.use_specialist:
        specialist = 0
        print("specialist", specialist)
        envs.update_specialist(specialist)
    if args.use_threshold_sampling:
        sampling_threshold = 200
        first_sampling = False
        uniform_sampling = True
        uniform_every = 500000
        uniform_counter = 1
        evaluate_envs = make_env(env_name, render=False)
        evaluate_envs.set_mirror(args.use_phase_mirror)
        evaluate_envs.update_curriculum(0)
        prob_filter = np.zeros((11, 11))
        prob_filter[5, 5] = 1
    if args.use_adaptive_sampling:
        evaluate_envs = make_env(env_name, render=False)
        evaluate_envs.set_mirror(args.use_phase_mirror)
        evaluate_envs.update_curriculum(0)
    if args.plot_prob:
        import matplotlib.pyplot as plt
        fig = plt.figure()
        plt.show(block=False)
        ax1 = fig.add_subplot(121)
        ax2 = fig.add_subplot(122)

    obs_shape = envs.observation_space.shape
    obs_shape = (obs_shape[0], *obs_shape[1:])

    if args.load_saved_controller:
        best_model = "{}_base.pt".format(env_name)
        model_path = os.path.join(current_dir, "models", best_model)
        print("Loading model {}".format(best_model))
        actor_critic = torch.load(model_path)
        actor_critic.reset_dist()
    else:
        controller = SoftsignActor(dummy_env)
        actor_critic = Policy(controller, num_ensembles=args.num_ensembles)

    mirror_function = None
    if args.use_mirror:
        indices = dummy_env.unwrapped.get_mirror_indices()
        mirror_function = get_mirror_function(indices)

    device = "cuda:0" if args.cuda else "cpu"
    if args.cuda:
        actor_critic.cuda()

    agent = PPO(actor_critic,
                mirror_function=mirror_function,
                **args.ppo_params)

    rollouts = RolloutStorage(
        args.num_steps,
        args.num_processes,
        obs_shape,
        envs.action_space.shape[0],
        actor_critic.state_size,
    )

    current_obs = torch.zeros(args.num_processes, *obs_shape)

    def update_current_obs(obs):
        shape_dim0 = envs.observation_space.shape[0]
        obs = torch.from_numpy(obs).float()
        current_obs[:, -shape_dim0:] = obs

    obs = envs.reset()
    update_current_obs(obs)

    rollouts.observations[0].copy_(current_obs)

    if args.cuda:
        current_obs = current_obs.cuda()
        rollouts.cuda()

    episode_rewards = deque(maxlen=args.num_processes)
    test_episode_rewards = deque(maxlen=args.num_tests)
    num_updates = int(args.num_frames) // args.num_steps // args.num_processes

    start = time.time()
    next_checkpoint = args.save_every
    max_ep_reward = float("-inf")

    logger = ConsoleCSVLogger(log_dir=args.experiment_dir,
                              console_log_interval=args.log_interval)

    update_values = False

    if args.save_sampling_prob:
        sampling_prob_list = []

    for j in range(num_updates):

        if args.lr_decay_type == "linear":
            scheduled_lr = linear_decay(j, num_updates, args.lr, final_value=0)
        elif args.lr_decay_type == "exponential":
            scheduled_lr = exponential_decay(j,
                                             0.99,
                                             args.lr,
                                             final_value=3e-5)
        else:
            scheduled_lr = args.lr

        set_optimizer_lr(agent.optimizer, scheduled_lr)

        ac_state_dict = copy.deepcopy(actor_critic).cpu().state_dict()

        if update_values and args.use_threshold_sampling:
            envs.update_curriculum(5)
        elif (not update_values
              ) and args.use_threshold_sampling and first_sampling:
            envs.update_specialist(0)

        if args.use_threshold_sampling and not uniform_sampling:
            obs = evaluate_envs.reset()
            yaw_size = dummy_env.yaw_samples.shape[0]
            pitch_size = dummy_env.pitch_samples.shape[0]
            total_metric = torch.zeros(1, yaw_size * pitch_size).to(device)
            evaluate_counter = 0
            while True:
                obs = torch.from_numpy(obs).float().unsqueeze(0).to(device)
                with torch.no_grad():
                    _, action, _, _ = actor_critic.act(obs,
                                                       None,
                                                       None,
                                                       deterministic=True)
                cpu_actions = action.squeeze().cpu().numpy()
                obs, reward, done, info = evaluate_envs.step(cpu_actions)
                if done:
                    obs = evaluate_envs.reset()
                if evaluate_envs.update_terrain:
                    evaluate_counter += 1
                    temp_states = evaluate_envs.create_temp_states()
                    with torch.no_grad():
                        temp_states = torch.from_numpy(temp_states).float().to(
                            device)
                        value_samples = actor_critic.get_ensemble_values(
                            temp_states, None, None)
                        #yaw_size = dummy_env.yaw_samples.shape[0]
                        mean = value_samples.mean(dim=-1)
                        #mean = value_samples.min(dim=-1)[0]
                        metric = mean.clone()
                        metric = metric.view(yaw_size, pitch_size)
                        #metric = metric / (metric.abs().max())
                        metric = metric.view(1, yaw_size * pitch_size)
                        total_metric += metric
                if evaluate_counter >= 5:
                    total_metric /= (total_metric.abs().max())
                    #total_metric[total_metric < 0.7] = 0
                    print("metric", total_metric)
                    sampling_probs = (
                        -10 * (total_metric - args.curriculum_threshold).abs()
                    ).softmax(dim=1).view(
                        yaw_size, pitch_size
                    )  #threshold1:150, 0.9 l2, threshold2: 10, 0.85 l1, threshold3: 10, 0.85, l1, 0.40 gap
                    #threshold 4: 20, 0.85, l1, yaw 10
                    if args.save_sampling_prob:
                        sampling_prob_list.append(sampling_probs.cpu().numpy())
                    sample_probs = np.zeros(
                        (args.num_processes, yaw_size, pitch_size))
                    #print("prob", sampling_probs)
                    for i in range(args.num_processes):
                        sample_probs[i, :, :] = np.copy(
                            sampling_probs.cpu().numpy().astype(np.float64))
                    envs.update_sample_prob(sample_probs)
                    break
        elif args.use_threshold_sampling and uniform_sampling:
            envs.update_curriculum(5)
        # if args.use_threshold_sampling and not uniform_sampling:
        #     obs = evaluate_envs.reset()
        #     yaw_size = dummy_env.yaw_samples.shape[0]
        #     pitch_size = dummy_env.pitch_samples.shape[0]
        #     r_size = dummy_env.r_samples.shape[0]
        #     total_metric = torch.zeros(1, yaw_size * pitch_size * r_size).to(device)
        #     evaluate_counter = 0
        #     while True:
        #         obs = torch.from_numpy(obs).float().unsqueeze(0).to(device)
        #         with torch.no_grad():
        #             _, action, _, _ = actor_critic.act(
        #             obs, None, None, deterministic=True
        #             )
        #         cpu_actions = action.squeeze().cpu().numpy()
        #         obs, reward, done, info = evaluate_envs.step(cpu_actions)
        #         if done:
        #             obs = evaluate_envs.reset()
        #         if evaluate_envs.update_terrain:
        #             evaluate_counter += 1
        #             temp_states = evaluate_envs.create_temp_states()
        #             with torch.no_grad():
        #                 temp_states = torch.from_numpy(temp_states).float().to(device)
        #                 value_samples = actor_critic.get_ensemble_values(temp_states, None, None)
        #                 mean = value_samples.mean(dim=-1)
        #                 #mean = value_samples.min(dim=-1)[0]
        #                 metric = mean.clone()
        #                 metric = metric.view(yaw_size, pitch_size, r_size)
        #                 #metric = metric / (metric.abs().max())
        #                 metric = metric.view(1, yaw_size*pitch_size*r_size)
        #                 total_metric += metric
        #         if evaluate_counter >= 5:
        #             total_metric /= (total_metric.abs().max())
        #             #total_metric[total_metric < 0.7] = 0
        #             #print("metric", total_metric)
        #             sampling_probs = (-10*(total_metric-0.85).abs()).softmax(dim=1).view(yaw_size, pitch_size, r_size) #threshold1:150, 0.9 l2, threshold2: 10, 0.85 l1, threshold3: 10, 0.85, l1, 0.40 gap
        #             #threshold 4: 3d grid, 10, 0.85, l1
        #             sample_probs = np.zeros((args.num_processes, yaw_size, pitch_size, r_size))
        #             #print("prob", sampling_probs)
        #             for i in range(args.num_processes):
        #                 sample_probs[i, :, :, :] = np.copy(sampling_probs.cpu().numpy().astype(np.float64))
        #             envs.update_sample_prob(sample_probs)
        #             break
        # elif args.use_threshold_sampling and uniform_sampling:
        #     envs.update_curriculum(5)

        if args.use_adaptive_sampling:
            obs = evaluate_envs.reset()
            yaw_size = dummy_env.yaw_samples.shape[0]
            pitch_size = dummy_env.pitch_samples.shape[0]
            total_metric = torch.zeros(1, yaw_size * pitch_size).to(device)
            evaluate_counter = 0
            while True:
                obs = torch.from_numpy(obs).float().unsqueeze(0).to(device)
                with torch.no_grad():
                    _, action, _, _ = actor_critic.act(obs,
                                                       None,
                                                       None,
                                                       deterministic=True)
                cpu_actions = action.squeeze().cpu().numpy()
                obs, reward, done, info = evaluate_envs.step(cpu_actions)
                if done:
                    obs = evaluate_envs.reset()
                if evaluate_envs.update_terrain:
                    evaluate_counter += 1
                    temp_states = evaluate_envs.create_temp_states()
                    with torch.no_grad():
                        temp_states = torch.from_numpy(temp_states).float().to(
                            device)
                        value_samples = actor_critic.get_ensemble_values(
                            temp_states, None, None)
                        mean = value_samples.mean(dim=-1)
                        metric = mean.clone()
                        metric = metric.view(yaw_size, pitch_size)
                        #metric = metric / metric.abs().max()
                        metric = metric.view(1, yaw_size * pitch_size)
                        total_metric += metric
                        # sampling_probs = (-30*metric).softmax(dim=1).view(size, size)
                        # sample_probs = np.zeros((args.num_processes, size, size))
                        # for i in range(args.num_processes):
                        #     sample_probs[i, :, :] = np.copy(sampling_probs.cpu().numpy().astype(np.float64))
                        # envs.update_sample_prob(sample_probs)
                if evaluate_counter >= 5:
                    total_metric /= (total_metric.abs().max())
                    print("metric", total_metric)
                    sampling_probs = (-10 * total_metric).softmax(dim=1).view(
                        yaw_size, pitch_size)
                    sample_probs = np.zeros(
                        (args.num_processes, yaw_size, pitch_size))
                    for i in range(args.num_processes):
                        sample_probs[i, :, :] = np.copy(
                            sampling_probs.cpu().numpy().astype(np.float64))
                    envs.update_sample_prob(sample_probs)
                    break

        for step in range(args.num_steps):
            # Sample actions
            with torch.no_grad():
                value, action, action_log_prob, states = actor_critic.act(
                    rollouts.observations[step],
                    rollouts.states[step],
                    rollouts.masks[step],
                    deterministic=update_values)
            cpu_actions = action.squeeze(1).cpu().numpy()

            obs, reward, done, infos = envs.step(cpu_actions)
            reward = torch.from_numpy(np.expand_dims(np.stack(reward),
                                                     1)).float()

            if args.plot_prob and step == 0:
                temp_states = envs.create_temp_states()
                with torch.no_grad():
                    temp_states = torch.from_numpy(temp_states).float().to(
                        device)
                    value_samples = actor_critic.get_value(
                        temp_states, None, None)
                size = dummy_env.yaw_samples.shape[0]
                v = value_samples.mean(dim=0).view(size, size).cpu().numpy()
                vs = value_samples.var(dim=0).view(size, size).cpu().numpy()
                ax1.pcolormesh(v)
                ax2.pcolormesh(vs)
                print(np.round(v, 2))
                fig.canvas.draw()

            # if args.use_adaptive_sampling:
            #     temp_states = envs.create_temp_states()
            #     with torch.no_grad():
            #         temp_states = torch.from_numpy(temp_states).float().to(device)
            #         value_samples = actor_critic.get_value(temp_states, None, None)

            #     size = dummy_env.yaw_samples.shape[0]
            #     sample_probs = (-value_samples / 5).softmax(dim=1).view(args.num_processes, size, size)
            #     envs.update_sample_prob(sample_probs.cpu().numpy())

            # if args.use_threshold_sampling and not uniform_sampling:
            #     temp_states = envs.create_temp_states()
            #     with torch.no_grad():
            #         temp_states = torch.from_numpy(temp_states).float().to(device)
            #         value_samples = actor_critic.get_ensemble_values(temp_states, None, None)
            #     size = dummy_env.yaw_samples.shape[0]
            #     mean = value_samples.mean(dim=-1)
            #     std = value_samples.std(dim=-1)

            #using std
            # metric = std.clone()
            # metric = metric.view(args.num_processes, size, size)
            # value_filter = torch.ones(args.num_processes, 11, 11).to(device) * -1e5
            # value_filter[:, 5 - curriculum: 5 + curriculum + 1, 5 - curriculum: 5 + curriculum + 1] = 0
            # metric = metric / metric.max() + value_filter
            # metric = metric.view(args.num_processes, size*size)
            # sample_probs = (30*metric).softmax(dim=1).view(args.num_processes, size, size)

            #using value estimate
            # metric = mean.clone()
            # metric = metric.view(args.num_processes, size, size)
            # value_filter = torch.ones(args.num_processes, 11, 11).to(device) * -1e5
            # value_filter[:, 5 - curriculum: 5 + curriculum + 1, 5 - curriculum: 5 + curriculum + 1] = 0
            # metric = metric / metric.abs().max() - value_filter
            # metric = metric.view(args.num_processes, size*size)
            # sample_probs = (-30*metric).softmax(dim=1).view(args.num_processes, size, size)

            # if args.plot_prob and step == 0:
            #     #print(sample_probs.cpu().numpy()[0, :, :])
            #     ax.pcolormesh(sample_probs.cpu().numpy()[0, :, :])
            #     print(np.round(sample_probs.cpu().numpy()[0, :, :], 4))
            #     fig.canvas.draw()
            # envs.update_sample_prob(sample_probs.cpu().numpy())

            #using value threshold
            # metric = mean.clone()
            # metric = metric.view(args.num_processes, size, size)
            # metric = metric / metric.abs().max()# - value_filter
            # metric = metric.view(args.num_processes, size*size)
            # sample_probs = (-30*(metric-0.8)**2).softmax(dim=1).view(args.num_processes, size, size)

            # if args.plot_prob and step == 0:
            #     ax.pcolormesh(sample_probs.cpu().numpy()[0, :, :])
            #     print(np.round(sample_probs.cpu().numpy()[0, :, :], 4))
            #     fig.canvas.draw()
            # envs.update_sample_prob(sample_probs.cpu().numpy())

            bad_masks = np.ones((args.num_processes, 1))
            for p_index, info in enumerate(infos):
                keys = info.keys()
                # This information is added by algorithms.utils.TimeLimitMask
                if "bad_transition" in keys:
                    bad_masks[p_index] = 0.0
                # This information is added by baselines.bench.Monitor
                if "episode" in keys:
                    episode_rewards.append(info["episode"]["r"])

            masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                       for done_ in done])
            bad_masks = torch.from_numpy(bad_masks)

            update_current_obs(obs)
            rollouts.insert(
                current_obs,
                states,
                action,
                action_log_prob,
                value,
                reward,
                masks,
                bad_masks,
            )

        obs = test_envs.reset()
        if args.use_threshold_sampling:
            if uniform_counter % uniform_every == 0:
                uniform_sampling = True
                uniform_counter = 0
            else:
                uniform_sampling = False
            uniform_counter += 1
            if uniform_sampling:
                envs.update_curriculum(5)
                print("uniform")

        #print("max_step", dummy_env._max_episode_steps)
        for step in range(dummy_env._max_episode_steps):
            # Sample actions
            with torch.no_grad():
                obs = torch.from_numpy(obs).float().to(device)
                _, action, _, _ = actor_critic.act(obs,
                                                   None,
                                                   None,
                                                   deterministic=True)
            cpu_actions = action.squeeze(1).cpu().numpy()

            obs, reward, done, infos = test_envs.step(cpu_actions)
            reward = torch.from_numpy(np.expand_dims(np.stack(reward),
                                                     1)).float()

            for p_index, info in enumerate(infos):
                keys = info.keys()
                # This information is added by baselines.bench.Monitor
                if "episode" in keys:
                    #print(info["episode"]["r"])
                    test_episode_rewards.append(info["episode"]["r"])

        if args.use_curriculum and np.mean(
                episode_rewards) > 1000 and curriculum <= 4:
            curriculum += 1
            print("curriculum", curriculum)
            envs.update_curriculum(curriculum)

        with torch.no_grad():
            next_value = actor_critic.get_value(rollouts.observations[-1],
                                                rollouts.states[-1],
                                                rollouts.masks[-1]).detach()

        rollouts.compute_returns(next_value, args.use_gae, args.gamma,
                                 args.gae_lambda)

        if update_values:
            value_loss = agent.update_values(rollouts)
        else:
            value_loss, action_loss, dist_entropy = agent.update(rollouts)
        #update_values = (not update_values)

        rollouts.after_update()

        frame_count = (j + 1) * args.num_steps * args.num_processes
        if (frame_count >= next_checkpoint
                or j == num_updates - 1) and args.save_dir != "":
            model_name = "{}_{:d}.pt".format(env_name, int(next_checkpoint))
            next_checkpoint += args.save_every
        else:
            model_name = "{}_latest.pt".format(env_name)

        if args.save_sampling_prob:
            import pickle
            with open('{}_sampling_prob85.pkl'.format(env_name), 'wb') as fp:
                pickle.dump(sampling_prob_list, fp)

        # A really ugly way to save a model to CPU
        save_model = actor_critic
        if args.cuda:
            save_model = copy.deepcopy(actor_critic).cpu()

        if args.use_specialist and np.mean(
                episode_rewards) > 1000 and specialist <= 4:
            specialist_name = "{}_specialist_{:d}.pt".format(
                env_name, int(specialist))
            specialist_model = actor_critic
            if args.cuda:
                specialist_model = copy.deepcopy(actor_critic).cpu()
            torch.save(specialist_model,
                       os.path.join(args.save_dir, specialist_name))
            specialist += 1
            envs.update_specialist(specialist)
        # if args.use_threshold_sampling and np.mean(episode_rewards) > 1000 and curriculum <= 4:
        #     first_sampling = False
        #     curriculum += 1
        #     print("curriculum", curriculum)
        #     envs.update_curriculum(curriculum)
        #     prob_filter[5-curriculum:5+curriculum+1, 5-curriculum:5+curriculum+1] = 1

        torch.save(save_model, os.path.join(args.save_dir, model_name))

        if len(episode_rewards) > 1 and np.mean(
                episode_rewards) > max_ep_reward:
            model_name = "{}_best.pt".format(env_name)
            max_ep_reward = np.mean(episode_rewards)
            torch.save(save_model, os.path.join(args.save_dir, model_name))

        if len(episode_rewards) > 1:
            end = time.time()
            total_num_steps = (j + 1) * args.num_processes * args.num_steps
            logger.log_epoch({
                "iter": j + 1,
                "total_num_steps": total_num_steps,
                "fps": int(total_num_steps / (end - start)),
                "entropy": dist_entropy,
                "value_loss": value_loss,
                "action_loss": action_loss,
                "stats": {
                    "rew": episode_rewards
                },
                "test_stats": {
                    "rew": test_episode_rewards
                },
            })
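# `rollouts.compute_returns(next_value, args.use_gae, args.gamma, args.gae_lambda)` above
# is not shown in this listing; a minimal sketch of the standard GAE(lambda) return
# computation such a call presumably performs (tensor shapes and the masks convention
# are assumptions based on the rollout storage used here):
import torch

def compute_gae_returns(rewards, values, masks, next_value, gamma=0.99, gae_lambda=0.95):
    # rewards, values, masks: [num_steps, num_processes, 1]; masks[t] is 0.0 if the
    # episode ended at step t, else 1.0. next_value bootstraps the state after the last step.
    values = torch.cat([values, next_value.unsqueeze(0)], dim=0)
    returns = torch.zeros_like(rewards)
    gae = torch.zeros_like(next_value)
    for step in reversed(range(rewards.size(0))):
        delta = rewards[step] + gamma * values[step + 1] * masks[step] - values[step]
        gae = delta + gamma * gae_lambda * masks[step] * gae
        returns[step] = gae + values[step]
    return returns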
Example #9
    def initialise_policy(self):

        # variables for task encoder (used for oracle)
        state_dim = self.envs.observation_space.shape[0]

        # TODO: this isn't ideal, find a nicer way to get the task dimension!
        if 'BeliefOracle' in self.args.env_name:
            task_dim = gym.make(self.args.env_name).observation_space.shape[0] - \
                       gym.make(self.args.env_name.replace('BeliefOracle', '')).observation_space.shape[0]
            latent_dim = self.args.latent_dim
            state_embedding_size = self.args.state_embedding_size
            use_task_encoder = True
        elif 'Oracle' in self.args.env_name:
            task_dim = gym.make(self.args.env_name).observation_space.shape[0] - \
                       gym.make(self.args.env_name.replace('Oracle', '')).observation_space.shape[0]
            latent_dim = self.args.latent_dim
            state_embedding_size = self.args.state_embedding_size
            use_task_encoder = True
        else:
            task_dim = latent_dim = state_embedding_size = 0
            use_task_encoder = False

        # initialise rollout storage for the policy
        self.policy_storage = OnlineStorage(
            self.args,
            self.args.policy_num_steps,
            self.args.num_processes,
            self.args.obs_dim,
            self.args.act_space,
            hidden_size=0,
            latent_dim=self.args.latent_dim,
            normalise_observations=self.args.norm_obs_for_policy,
            normalise_rewards=self.args.norm_rew_for_policy,
        )

        if hasattr(self.envs.action_space, 'low'):
            action_low = self.envs.action_space.low
            action_high = self.envs.action_space.high
        else:
            action_low = action_high = None

        # initialise policy network
        policy_net = Policy(
            # general
            state_dim=int(self.args.condition_policy_on_state) * state_dim,
            action_space=self.envs.action_space,
            init_std=self.args.policy_init_std,
            hidden_layers=self.args.policy_layers,
            activation_function=self.args.policy_activation_function,
            use_task_encoder=use_task_encoder,
            # task encoding things (for oracle)
            task_dim=task_dim,
            latent_dim=latent_dim,
            state_embed_dim=state_embedding_size,
            #
            normalise_actions=self.args.normalise_actions,
            action_low=action_low,
            action_high=action_high,
        ).to(device)

        # initialise policy
        if self.args.policy == 'a2c':
            # initialise policy trainer (A2C)
            self.policy = A2C(
                policy_net,
                self.args.policy_value_loss_coef,
                self.args.policy_entropy_coef,
                lr=self.args.lr_policy,
                eps=self.args.policy_eps,
                alpha=self.args.a2c_alpha,
            )
        elif self.args.policy == 'ppo':
            # initialise policy network
            self.policy = PPO(
                policy_net,
                self.args.policy_value_loss_coef,
                self.args.policy_entropy_coef,
                lr=self.args.lr_policy,
                eps=self.args.policy_eps,
                ppo_epoch=self.args.ppo_num_epochs,
                num_mini_batch=self.args.ppo_num_minibatch,
                use_huber_loss=self.args.ppo_use_huberloss,
                use_clipped_value_loss=self.args.ppo_use_clipped_value_loss,
                clip_param=self.args.ppo_clip_param,
            )
        else:
            raise NotImplementedError
Example #10
class Learner:
    """
    Learner (no meta-learning), can be used to train Oracle policies.
    """
    def __init__(self, args):
        self.args = args

        # make sure everything has the same seed
        utl.seed(self.args.seed, self.args.deterministic_execution)

        # initialise tensorboard logger
        self.logger = TBLogger(self.args, self.args.exp_label)

        # initialise environments
        self.envs = make_vec_envs(
            env_name=args.env_name,
            seed=args.seed,
            num_processes=args.num_processes,
            gamma=args.policy_gamma,
            log_dir=args.agent_log_dir,
            device=device,
            allow_early_resets=False,
            episodes_per_task=self.args.max_rollouts_per_task,
            obs_rms=None,
            ret_rms=None,
        )

        # calculate what the maximum length of the trajectories is
        args.max_trajectory_len = self.envs._max_episode_steps
        args.max_trajectory_len *= self.args.max_rollouts_per_task

        # calculate number of meta updates
        self.args.num_updates = int(
            args.num_frames) // args.policy_num_steps // args.num_processes

        # get action / observation dimensions
        if isinstance(self.envs.action_space, gym.spaces.discrete.Discrete):
            self.args.action_dim = 1
        else:
            self.args.action_dim = self.envs.action_space.shape[0]
        self.args.obs_dim = self.envs.observation_space.shape[0]
        self.args.num_states = self.envs.num_states if str.startswith(
            self.args.env_name, 'Grid') else None
        self.args.act_space = self.envs.action_space

        self.initialise_policy()

        # count number of frames and updates
        self.frames = 0
        self.iter_idx = 0

    def initialise_policy(self):

        # variables for task encoder (used for oracle)
        state_dim = self.envs.observation_space.shape[0]

        # TODO: this isn't ideal, find a nicer way to get the task dimension!
        if 'BeliefOracle' in self.args.env_name:
            task_dim = gym.make(self.args.env_name).observation_space.shape[0] - \
                       gym.make(self.args.env_name.replace('BeliefOracle', '')).observation_space.shape[0]
            latent_dim = self.args.latent_dim
            state_embedding_size = self.args.state_embedding_size
            use_task_encoder = True
        elif 'Oracle' in self.args.env_name:
            task_dim = gym.make(self.args.env_name).observation_space.shape[0] - \
                       gym.make(self.args.env_name.replace('Oracle', '')).observation_space.shape[0]
            latent_dim = self.args.latent_dim
            state_embedding_size = self.args.state_embedding_size
            use_task_encoder = True
        else:
            task_dim = latent_dim = state_embedding_size = 0
            use_task_encoder = False

        # initialise rollout storage for the policy
        self.policy_storage = OnlineStorage(
            self.args,
            self.args.policy_num_steps,
            self.args.num_processes,
            self.args.obs_dim,
            self.args.act_space,
            hidden_size=0,
            latent_dim=self.args.latent_dim,
            normalise_observations=self.args.norm_obs_for_policy,
            normalise_rewards=self.args.norm_rew_for_policy,
        )

        if hasattr(self.envs.action_space, 'low'):
            action_low = self.envs.action_space.low
            action_high = self.envs.action_space.high
        else:
            action_low = action_high = None

        # initialise policy network
        policy_net = Policy(
            # general
            state_dim=int(self.args.condition_policy_on_state) * state_dim,
            action_space=self.envs.action_space,
            init_std=self.args.policy_init_std,
            hidden_layers=self.args.policy_layers,
            activation_function=self.args.policy_activation_function,
            use_task_encoder=use_task_encoder,
            # task encoding things (for oracle)
            task_dim=task_dim,
            latent_dim=latent_dim,
            state_embed_dim=state_embedding_size,
            #
            normalise_actions=self.args.normalise_actions,
            action_low=action_low,
            action_high=action_high,
        ).to(device)

        # initialise policy
        if self.args.policy == 'a2c':
            # initialise policy trainer (A2C)
            self.policy = A2C(
                policy_net,
                self.args.policy_value_loss_coef,
                self.args.policy_entropy_coef,
                lr=self.args.lr_policy,
                eps=self.args.policy_eps,
                alpha=self.args.a2c_alpha,
            )
        elif self.args.policy == 'ppo':
            # initialise policy network
            self.policy = PPO(
                policy_net,
                self.args.policy_value_loss_coef,
                self.args.policy_entropy_coef,
                lr=self.args.lr_policy,
                eps=self.args.policy_eps,
                ppo_epoch=self.args.ppo_num_epochs,
                num_mini_batch=self.args.ppo_num_minibatch,
                use_huber_loss=self.args.ppo_use_huberloss,
                use_clipped_value_loss=self.args.ppo_use_clipped_value_loss,
                clip_param=self.args.ppo_clip_param,
            )
        else:
            raise NotImplementedError

    def train(self):
        """
        Given some stream of environments and a logger (tensorboard),
        (meta-)trains the policy.
        """

        start_time = time.time()

        # reset environments
        (prev_obs_raw, prev_obs_normalised) = self.envs.reset()
        prev_obs_raw = prev_obs_raw.to(device)
        prev_obs_normalised = prev_obs_normalised.to(device)

        # insert initial observation / embeddings to rollout storage
        self.policy_storage.prev_obs_raw[0].copy_(prev_obs_raw)
        self.policy_storage.prev_obs_normalised[0].copy_(prev_obs_normalised)
        self.policy_storage.to(device)

        for self.iter_idx in range(self.args.num_updates):

            # check if we flushed the policy storage
            assert len(self.policy_storage.latent_mean) == 0

            # rollouts policies for a few steps
            for step in range(self.args.policy_num_steps):

                # sample actions from policy
                with torch.no_grad():
                    value, action, action_log_prob = utl.select_action(
                        policy=self.policy,
                        args=self.args,
                        obs=prev_obs_normalised
                        if self.args.norm_obs_for_policy else prev_obs_raw,
                        deterministic=False)

                # observe reward and next obs
                (next_obs_raw, next_obs_normalised), (
                    rew_raw, rew_normalised), done, infos = utl.env_step(
                        self.envs, action)
                action = action.float()

                # create mask for episode ends
                masks_done = torch.FloatTensor([[0.0] if done_ else [1.0]
                                                for done_ in done]).to(device)
                # bad_mask is true if episode ended because time limit was reached
                bad_masks = torch.FloatTensor(
                    [[0.0] if 'bad_transition' in info.keys() else [1.0]
                     for info in infos]).to(device)

                # add the obs before reset to the policy storage
                self.policy_storage.next_obs_raw[step] = next_obs_raw.clone()
                self.policy_storage.next_obs_normalised[step] = next_obs_normalised.clone()

                # reset environments that are done
                done_indices = np.argwhere(done.flatten()).flatten()
                if len(done_indices) == self.args.num_processes:
                    [next_obs_raw, next_obs_normalised] = self.envs.reset()
                else:
                    for i in done_indices:
                        [next_obs_raw[i],
                         next_obs_normalised[i]] = self.envs.reset(index=i)

                # add experience to policy buffer
                self.policy_storage.insert(
                    obs_raw=next_obs_raw.clone(),
                    obs_normalised=next_obs_normalised.clone(),
                    actions=action.clone(),
                    action_log_probs=action_log_prob.clone(),
                    rewards_raw=rew_raw.clone(),
                    rewards_normalised=rew_normalised.clone(),
                    value_preds=value.clone(),
                    masks=masks_done.clone(),
                    bad_masks=bad_masks.clone(),
                    done=torch.from_numpy(np.array(
                        done, dtype=float)).unsqueeze(1).clone(),
                )

                prev_obs_normalised = next_obs_normalised
                prev_obs_raw = next_obs_raw

                self.frames += self.args.num_processes

            # --- UPDATE ---

            train_stats = self.update(
                prev_obs_normalised if self.args.norm_obs_for_policy else prev_obs_raw)

            # log
            run_stats = [action, action_log_prob, value]
            if train_stats is not None:
                self.log(run_stats, train_stats, start_time)

            # clean up after update
            self.policy_storage.after_update()

    def get_value(self, obs):
        obs = utl.get_augmented_obs(args=self.args, obs=obs)
        return self.policy.actor_critic.get_value(obs).detach()

    def update(self, obs):
        """
        Meta-update.
        Here the policy is updated for good average performance across tasks.
        :return:    policy_train_stats which are: value_loss_epoch, action_loss_epoch, dist_entropy_epoch, loss_epoch
        """
        # bootstrap next value prediction
        with torch.no_grad():
            next_value = self.get_value(obs)

        # compute returns for current rollouts
        self.policy_storage.compute_returns(
            next_value,
            self.args.policy_use_gae,
            self.args.policy_gamma,
            self.args.policy_tau,
            use_proper_time_limits=self.args.use_proper_time_limits)

        policy_train_stats = self.policy.update(
            args=self.args, policy_storage=self.policy_storage)

        return policy_train_stats, None

    def log(self, run_stats, train_stats, start):
        """
        Evaluate policy, save model, write to tensorboard logger.
        """
        train_stats, meta_train_stats = train_stats

        # --- visualise behaviour of policy ---

        if self.iter_idx % self.args.vis_interval == 0:
            obs_rms = self.envs.venv.obs_rms if self.args.norm_obs_for_policy else None
            ret_rms = self.envs.venv.ret_rms if self.args.norm_rew_for_policy else None

            utl_eval.visualise_behaviour(
                args=self.args,
                policy=self.policy,
                image_folder=self.logger.full_output_folder,
                iter_idx=self.iter_idx,
                obs_rms=obs_rms,
                ret_rms=ret_rms,
            )

        # --- evaluate policy ----

        if self.iter_idx % self.args.eval_interval == 0:

            obs_rms = self.envs.venv.obs_rms if self.args.norm_obs_for_policy else None
            ret_rms = self.envs.venv.ret_rms if self.args.norm_rew_for_policy else None

            returns_per_episode = utl_eval.evaluate(args=self.args,
                                                    policy=self.policy,
                                                    obs_rms=obs_rms,
                                                    ret_rms=ret_rms,
                                                    iter_idx=self.iter_idx)

            # log the average return across tasks (=processes)
            returns_avg = returns_per_episode.mean(dim=0)
            returns_std = returns_per_episode.std(dim=0)
            for k in range(len(returns_avg)):
                self.logger.add('return_avg_per_iter/episode_{}'.format(k + 1),
                                returns_avg[k], self.iter_idx)
                self.logger.add(
                    'return_avg_per_frame/episode_{}'.format(k + 1),
                    returns_avg[k], self.frames)
                self.logger.add('return_std_per_iter/episode_{}'.format(k + 1),
                                returns_std[k], self.iter_idx)
                self.logger.add(
                    'return_std_per_frame/episode_{}'.format(k + 1),
                    returns_std[k], self.frames)

            print(
                "Updates {}, num timesteps {}, FPS {} \n Mean return (train): {:.5f} \n"
                .format(self.iter_idx, self.frames,
                        int(self.frames / (time.time() - start)),
                        returns_avg[-1].item()))

        # save model
        if self.iter_idx % self.args.save_interval == 0:
            save_path = os.path.join(self.logger.full_output_folder, 'models')
            if not os.path.exists(save_path):
                os.mkdir(save_path)
            torch.save(
                self.policy.actor_critic,
                os.path.join(save_path, "policy{0}.pt".format(self.iter_idx)))

            # save normalisation params of envs
            if self.args.norm_rew_for_policy:
                # save rolling mean and std
                rew_rms = self.envs.venv.ret_rms
                utl.save_obj(rew_rms, save_path,
                             "env_rew_rms{0}.pkl".format(self.iter_idx))
            if self.args.norm_obs_for_policy:
                obs_rms = self.envs.venv.obs_rms
                utl.save_obj(obs_rms, save_path,
                             "env_obs_rms{0}.pkl".format(self.iter_idx))

        # --- log some other things ---

        if self.iter_idx % self.args.log_interval == 0:
            self.logger.add('policy_losses/value_loss', train_stats[0],
                            self.iter_idx)
            self.logger.add('policy_losses/action_loss', train_stats[1],
                            self.iter_idx)
            self.logger.add('policy_losses/dist_entropy', train_stats[2],
                            self.iter_idx)
            self.logger.add('policy_losses/sum', train_stats[3], self.iter_idx)

            # writer.add_scalar('policy/action', action.mean(), j)
            self.logger.add('policy/action', run_stats[0][0].float().mean(),
                            self.iter_idx)
            if hasattr(self.policy.actor_critic, 'logstd'):
                self.logger.add('policy/action_logstd',
                                self.policy.actor_critic.dist.logstd.mean(),
                                self.iter_idx)
            self.logger.add('policy/action_logprob', run_stats[1].mean(),
                            self.iter_idx)
            self.logger.add('policy/value', run_stats[2].mean(), self.iter_idx)

            param_list = list(self.policy.actor_critic.parameters())
            param_mean = np.mean(
                [param_list[i].data.mean().item() for i in range(len(param_list))])
            param_grad_mean = np.mean(
                [param_list[i].grad.mean().item() for i in range(len(param_list))])
            self.logger.add('weights/policy', param_mean, self.iter_idx)
            self.logger.add('weights/policy_std', param_list[0].data.mean(),
                            self.iter_idx)
            self.logger.add('gradients/policy', param_grad_mean, self.iter_idx)
            self.logger.add('gradients/policy_std', param_list[0].grad.mean(),
                            self.iter_idx)
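The weight and gradient statistics in the log method above assume every parameter already has a populated .grad; on GPU they also rely on NumPy being able to reduce a list of per-parameter means. A minimal sketch of a more defensive variant (the helper name param_stats and the None-grad guard are illustrative additions, not part of the project):

import numpy as np
import torch


def param_stats(model: torch.nn.Module):
    """Return the mean of parameter values and the mean of available gradients."""
    params = list(model.parameters())
    # reduce each parameter to a Python float before averaging with NumPy
    value_mean = np.mean([p.data.mean().item() for p in params])
    # skip parameters whose .grad has not been populated yet
    grad_means = [p.grad.mean().item() for p in params if p.grad is not None]
    grad_mean = np.mean(grad_means) if grad_means else float('nan')
    return value_mean, grad_mean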
Example #11
0
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--env_name',
                        type=str,
                        default='coinrun',
                        help='name of the environment to train on.')
    parser.add_argument('--model',
                        type=str,
                        default='ppo',
                        help='the model to use for training. {ppo, ppo_aup}')
    args, rest_args = parser.parse_known_args()
    env_name = args.env_name
    model = args.model

    # --- ARGUMENTS ---

    if model == 'ppo':
        args = args_ppo.get_args(rest_args)
    elif model == 'ppo_aup':
        args = args_ppo_aup.get_args(rest_args)
    else:
        raise NotImplementedError

    # place other args back into argparse.Namespace
    args.env_name = env_name
    args.model = model

    # warnings
    if args.deterministic_execution:
        print('Invoking deterministic code execution.')
        if torch.backends.cudnn.enabled:
            warnings.warn('Running with deterministic CUDNN.')
        if args.num_processes > 1:
            raise RuntimeError(
                'If you want fully deterministic code, run it with num_processes=1. '
                'Warning: This will slow things down and might break A2C if '
                'policy_num_steps < env._max_episode_steps.')

    # --- TRAINING ---
    print("Setting up wandb logging.")

    # Weights & Biases logger
    if args.run_name is None:
        # make run name as {env_name}_{algo}_{time}
        now = datetime.datetime.now().strftime('_%d-%m_%H:%M:%S')
        args.run_name = args.env_name + '_' + args.algo + now
    # initialise wandb
    wandb.init(project=args.proj_name,
               name=args.run_name,
               group=args.group_name,
               config=args,
               monitor_gym=False)
    # save wandb dir path
    args.run_dir = wandb.run.dir
    # make directory for saving models
    save_dir = os.path.join(wandb.run.dir, 'models')
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    # set random seed of random, torch and numpy
    utl.set_global_seed(args.seed, args.deterministic_execution)

    print("Setting up Environments.")
    # initialise environments for training
    train_envs = make_vec_envs(env_name=args.env_name,
                               start_level=args.train_start_level,
                               num_levels=args.train_num_levels,
                               distribution_mode=args.distribution_mode,
                               paint_vel_info=args.paint_vel_info,
                               num_processes=args.num_processes,
                               num_frame_stack=args.num_frame_stack,
                               device=device)
    # initialise environments for evaluation
    eval_envs = make_vec_envs(env_name=args.env_name,
                              start_level=0,
                              num_levels=0,
                              distribution_mode=args.distribution_mode,
                              paint_vel_info=args.paint_vel_info,
                              num_processes=args.num_processes,
                              num_frame_stack=args.num_frame_stack,
                              device=device)
    _ = eval_envs.reset()

    print("Setting up Actor-Critic model and Training algorithm.")
    # initialise policy network
    actor_critic = ACModel(obs_shape=train_envs.observation_space.shape,
                           action_space=train_envs.action_space,
                           hidden_size=args.hidden_size).to(device)

    # initialise policy training algorithm
    if args.algo == 'ppo':
        policy = PPO(actor_critic=actor_critic,
                     ppo_epoch=args.policy_ppo_epoch,
                     num_mini_batch=args.policy_num_mini_batch,
                     clip_param=args.policy_clip_param,
                     value_loss_coef=args.policy_value_loss_coef,
                     entropy_coef=args.policy_entropy_coef,
                     max_grad_norm=args.policy_max_grad_norm,
                     lr=args.policy_lr,
                     eps=args.policy_eps)
    else:
        raise NotImplementedError

    # initialise rollout storage for the policy training algorithm
    rollouts = RolloutStorage(num_steps=args.policy_num_steps,
                              num_processes=args.num_processes,
                              obs_shape=train_envs.observation_space.shape,
                              action_space=train_envs.action_space)

    # initialise Q_aux function(s) for AUP
    if args.use_aup:
        print("Initialising Q_aux models.")
        q_aux = [
            QModel(obs_shape=train_envs.observation_space.shape,
                   action_space=train_envs.action_space,
                   hidden_size=args.hidden_size).to(device)
            for _ in range(args.num_q_aux)
        ]
        if args.num_q_aux == 1:
            # load weights to model
            path = args.q_aux_dir + "0.pt"
            q_aux[0].load_state_dict(torch.load(path))
            q_aux[0].eval()
        else:
            # get max number of q_aux functions to choose from
            args.max_num_q_aux = len(os.listdir(args.q_aux_dir))
            q_aux_models = random.sample(list(range(0, args.max_num_q_aux)),
                                         args.num_q_aux)
            # load weights to models
            for i, model in enumerate(q_aux):
                path = args.q_aux_dir + str(q_aux_models[i]) + ".pt"
                model.load_state_dict(torch.load(path))
                model.eval()

    # count number of frames and updates
    frames = 0
    iter_idx = 0

    # update wandb args
    wandb.config.update(args)

    update_start_time = time.time()
    # reset environments
    obs = train_envs.reset()  # obs.shape = (num_processes,C,H,W)
    # insert initial observation to rollout storage
    rollouts.obs[0].copy_(obs)
    rollouts.to(device)

    # initialise buffer for calculating mean episodic returns
    episode_info_buf = deque(maxlen=10)

    # calculate number of updates
    # number of frames ÷ number of policy steps before update ÷ number of processes
    args.num_batch = args.num_processes * args.policy_num_steps
    args.num_updates = int(args.num_frames) // args.num_batch

    # define AUP coefficient
    if args.use_aup:
        aup_coef = args.aup_coef_start
        aup_linear_increase_val = math.exp(
            math.log(args.aup_coef_end / args.aup_coef_start) /
            args.num_updates)

    print("Training beginning.")
    print("Number of updates: ", args.num_updates)
    for iter_idx in range(args.num_updates):
        print("Iter: ", iter_idx)

        # put actor-critic into train mode
        actor_critic.train()

        if args.use_aup:
            aup_measures = defaultdict(list)

        # rollout policy to collect num_batch of experience and place in storage
        for step in range(args.policy_num_steps):

            # sample actions from policy
            with torch.no_grad():
                value, action, action_log_prob = actor_critic.act(
                    rollouts.obs[step])

            # observe rewards and next obs
            obs, reward, done, infos = train_envs.step(action)

            # calculate AUP reward
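            # (assumption from this snippet: the penalty below averages, over the auxiliary
            # Q-functions, |Q_aux(s, a) - Q_aux(s, a_noop)|, with action index 4 taken to be
            # procgen's no-op action)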
            if args.use_aup:
                intrinsic_reward = torch.zeros_like(reward)
                with torch.no_grad():
                    for model in q_aux:
                        # get action-values
                        action_values = model.get_action_value(
                            rollouts.obs[step])
                        # get action-value for action taken
                        action_value = torch.sum(
                            action_values * torch.nn.functional.one_hot(
                                action,
                                num_classes=train_envs.action_space.n).squeeze(
                                    dim=1),
                            dim=1)
                        # calculate the penalty
                        intrinsic_reward += torch.abs(
                            action_value.unsqueeze(dim=1) -
                            action_values[:, 4].unsqueeze(dim=1))
                intrinsic_reward /= args.num_q_aux
                # subtract the scaled AUP penalty from the extrinsic reward
                reward -= aup_coef * intrinsic_reward
                # log the intrinsic reward from the first env.
                aup_measures['intrinsic_reward'].append(aup_coef *
                                                        intrinsic_reward[0, 0])
                if done[0] and infos[0]['prev_level_complete'] == 1:
                    aup_measures['episode_complete'].append(2)
                elif done[0] and infos[0]['prev_level_complete'] == 0:
                    aup_measures['episode_complete'].append(1)
                else:
                    aup_measures['episode_complete'].append(0)

            # log episode info if episode finished
            for info in infos:
                if 'episode' in info.keys():
                    episode_info_buf.append(info['episode'])

            # create mask for episode ends
            masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                       for done_ in done]).to(device)

            # add experience to storage
            rollouts.insert(obs, reward, action, value, action_log_prob, masks)

            frames += args.num_processes

        # increase aup coefficient after every update (geometric schedule from aup_coef_start to aup_coef_end)
        if args.use_aup:
            aup_coef *= aup_linear_increase_val

        # --- UPDATE ---

        # bootstrap next value prediction
        with torch.no_grad():
            next_value = actor_critic.get_value(rollouts.obs[-1]).detach()

        # compute returns for current rollouts
        rollouts.compute_returns(next_value, args.policy_gamma,
                                 args.policy_gae_lambda)

        # update actor-critic using policy training algorithm
        total_loss, value_loss, action_loss, dist_entropy = policy.update(
            rollouts)

        # clean up storage after update
        rollouts.after_update()

        # --- LOGGING ---

        if iter_idx % args.log_interval == 0 or iter_idx == args.num_updates - 1:

            # --- EVALUATION ---
            eval_episode_info_buf = utl_eval.evaluate(
                eval_envs=eval_envs, actor_critic=actor_critic, device=device)

            # get stats for run
            update_end_time = time.time()
            num_interval_updates = 1 if iter_idx == 0 else args.log_interval
            fps = num_interval_updates * (
                args.num_processes *
                args.policy_num_steps) / (update_end_time - update_start_time)
            update_start_time = update_end_time
            # calculates whether the value function is a good predictor of the returns (ev close to 1)
            # or if it's just worse than predicting nothing (ev <= 0)
            ev = utl_math.explained_variance(utl.sf01(rollouts.value_preds),
                                             utl.sf01(rollouts.returns))

            if args.use_aup:
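                # back-date the per-step AUP logs: start from the first frame collected in
                # this update and advance the wandb step by num_processes per policy step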
                step = frames - args.num_processes * args.policy_num_steps
                for i in range(args.policy_num_steps):
                    wandb.log(
                        {
                            'aup/intrinsic_reward':
                            aup_measures['intrinsic_reward'][i],
                            'aup/episode_complete':
                            aup_measures['episode_complete'][i]
                        },
                        step=step)
                    step += args.num_processes

            wandb.log(
                {
                    'misc/timesteps':
                    frames,
                    'misc/fps':
                    fps,
                    'misc/explained_variance':
                    float(ev),
                    'losses/total_loss':
                    total_loss,
                    'losses/value_loss':
                    value_loss,
                    'losses/action_loss':
                    action_loss,
                    'losses/dist_entropy':
                    dist_entropy,
                    'train/mean_episodic_return':
                    utl_math.safe_mean([
                        episode_info['r'] for episode_info in episode_info_buf
                    ]),
                    'train/mean_episodic_length':
                    utl_math.safe_mean([
                        episode_info['l'] for episode_info in episode_info_buf
                    ]),
                    'eval/mean_episodic_return':
                    utl_math.safe_mean([
                        episode_info['r']
                        for episode_info in eval_episode_info_buf
                    ]),
                    'eval/mean_episodic_length':
                    utl_math.safe_mean([
                        episode_info['l']
                        for episode_info in eval_episode_info_buf
                    ])
                },
                step=frames)

        # --- SAVE MODEL ---

        # save for every interval-th episode or for the last epoch
        if iter_idx != 0 and (iter_idx % args.save_interval == 0
                              or iter_idx == args.num_updates - 1):
            print("Saving Actor-Critic Model.")
            torch.save(actor_critic.state_dict(),
                       os.path.join(save_dir, "policy{0}.pt".format(iter_idx)))

    # close envs
    train_envs.close()
    eval_envs.close()

    # --- TEST ---

    if args.test:
        print("Testing beginning.")
        episodic_return = utl_test.test(args=args,
                                        actor_critic=actor_critic,
                                        device=device)

        # save returns from train and test levels to analyse using interactive mode
        train_levels = torch.arange(
            args.train_start_level,
            args.train_start_level + args.train_num_levels)
        for i, level in enumerate(train_levels):
            wandb.log({
                'test/train_levels': level,
                'test/train_returns': episodic_return[0][i]
            })
        test_levels = torch.arange(
            args.test_start_level,
            args.test_start_level + args.test_num_levels)
        for i, level in enumerate(test_levels):
            wandb.log({
                'test/test_levels': level,
                'test/test_returns': episodic_return[1][i]
            })
        # log returns from test envs
        wandb.run.summary["train_mean_episodic_return"] = utl_math.safe_mean(
            episodic_return[0])
        wandb.run.summary["test_mean_episodic_return"] = utl_math.safe_mean(
            episodic_return[1])
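Note that the aup_coef update in Example #11 multiplies by a fixed factor each iteration, so the coefficient follows a geometric schedule from aup_coef_start to aup_coef_end rather than a linear one. A small standalone sketch under that assumption (the function name and the numbers are illustrative, not from the project):

import math


def aup_coef_schedule(coef_start, coef_end, num_updates):
    """Yield the AUP coefficient used at each update under a geometric schedule."""
    growth = math.exp(math.log(coef_end / coef_start) / num_updates)
    coef = coef_start
    for _ in range(num_updates):
        yield coef
        coef *= growth


coefs = list(aup_coef_schedule(0.001, 0.1, 100))
# after num_updates multiplications the coefficient reaches coef_end (up to float error)
assert abs(coefs[0] - 0.001) < 1e-12
assert abs(coefs[-1] * math.exp(math.log(0.1 / 0.001) / 100) - 0.1) < 1e-9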
Example #12
0
        x = tf.layers.conv2d(x, 64, 3, 1, activation=tf.nn.relu)
        x = tf.contrib.layers.flatten(x)
        x = tf.layers.dense(x, 512, activation=tf.nn.relu)

        logit_action_probability = tf.layers.dense(
            x,
            action_space,
            kernel_initializer=tf.truncated_normal_initializer(0.0, 0.01))
        state_value = tf.squeeze(
            tf.layers.dense(
                x, 1, kernel_initializer=tf.truncated_normal_initializer()))
        return logit_action_probability, state_value

    ppo = PPO(action_space,
              obs_fn,
              model_fn,
              train_epoch=5,
              batch_size=64,
              save_path='./raiden2_model')

    env = Raiden2(6666, num_envs=1, with_stack=True)
    env_ids, states, rewards, dones = env.start()

    nth_trajectory = 0
    while True:
        nth_trajectory += 1
        for _ in tqdm(range(explore_steps)):
            actions = ppo.get_action(np.asarray(states))
            actions = [(action, 4) for action in actions]
            env_ids, states, rewards, dones = env.step(env_ids, actions)

        logging.info(f'>>>> mean_reward={env.mean_reward}, nth_trajectory={nth_trajectory}')
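Example #12 has model_fn return unnormalised action logits together with a state value; presumably ppo.get_action turns those logits into sampled actions. A hedged NumPy sketch of that sampling step (sample_actions is not part of the example's PPO class):

import numpy as np


def sample_actions(logits, rng=None):
    """Sample one action per row of a (batch, num_actions) array of logits."""
    rng = np.random.default_rng() if rng is None else rng
    # softmax with max-subtraction for numerical stability
    z = logits - logits.max(axis=1, keepdims=True)
    probs = np.exp(z) / np.exp(z).sum(axis=1, keepdims=True)
    # one categorical draw per batch element
    return np.array([rng.choice(len(p), p=p) for p in probs])


actions = sample_actions(np.array([[0.1, 2.0, -1.0], [0.0, 0.0, 0.0]]))
print(actions.shape)  # (2,)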
Example #13
0
File: main.py Project: udeepam/aldm
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--env_name',
                        type=str,
                        default='coinrun',
                        help='name of the environment to train on.')
    parser.add_argument(
        '--model',
        type=str,
        default='ppo',
        help='the model to use for training. {ppo, ibac, ibac_sni, dist_match}'
    )
    args, rest_args = parser.parse_known_args()
    env_name = args.env_name
    model = args.model

    # --- ARGUMENTS ---

    if model == 'ppo':
        args = args_ppo.get_args(rest_args)
    elif model == 'ibac':
        args = args_ibac.get_args(rest_args)
    elif model == 'ibac_sni':
        args = args_ibac_sni.get_args(rest_args)
    elif model == 'dist_match':
        args = args_dist_match.get_args(rest_args)
    else:
        raise NotImplementedError

    # place other args back into argparse.Namespace
    args.env_name = env_name
    args.model = model
    args.num_train_envs = args.num_processes - args.num_val_envs if args.num_val_envs > 0 else args.num_processes

    # warnings
    if args.deterministic_execution:
        print('Invoking deterministic code execution.')
        if torch.backends.cudnn.enabled:
            warnings.warn('Running with deterministic CUDNN.')
        if args.num_processes > 1:
            raise RuntimeError(
                'If you want fully deterministic code, run it with num_processes=1. '
                'Warning: This will slow things down and might break A2C if '
                'policy_num_steps < env._max_episode_steps.')

    elif args.num_val_envs > 0 and (args.num_val_envs >= args.num_processes
                                    or not args.percentage_levels_train < 1.0):
        raise ValueError(
            'If --num_val_envs>0 then you must also have '
            '--num_val_envs < --num_processes and 0 < --percentage_levels_train < 1.'
        )

    elif args.num_val_envs > 0 and not args.use_dist_matching and args.dist_matching_coef != 0:
        raise ValueError(
            'If --num_val_envs>0 and --use_dist_matching=False then you must also have '
            '--dist_matching_coef=0.')

    elif args.use_dist_matching and not args.num_val_envs > 0:
        raise ValueError(
            'If --use_dist_matching=True then you must also have '
            '0 < --num_val_envs < --num_processes and 0 < --percentage_levels_train < 1.'
        )

    elif args.analyse_rep and not args.use_bottleneck:
        raise ValueError('If --analyse_rep=True then you must also have '
                         '--use_bottleneck=True.')

    # --- TRAINING ---
    print("Setting up wandb logging.")

    # Weights & Biases logger
    if args.run_name is None:
        # make run name as {env_name}_{algo}_{time}
        now = datetime.datetime.now().strftime('_%d-%m_%H:%M:%S')
        args.run_name = args.env_name + '_' + args.algo + now
    # initialise wandb
    wandb.init(project=args.proj_name,
               name=args.run_name,
               group=args.group_name,
               config=args,
               monitor_gym=False)
    # save wandb dir path
    args.run_dir = wandb.run.dir
    # make directory for saving models
    save_dir = os.path.join(wandb.run.dir, 'models')
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    # set random seed of random, torch and numpy
    utl.set_global_seed(args.seed, args.deterministic_execution)

    # initialise environments for training
    print("Setting up Environments.")
    if args.num_val_envs > 0:
        train_num_levels = int(args.train_num_levels *
                               args.percentage_levels_train)
        val_start_level = args.train_start_level + train_num_levels
        val_num_levels = args.train_num_levels - train_num_levels
        train_envs = make_vec_envs(env_name=args.env_name,
                                   start_level=args.train_start_level,
                                   num_levels=train_num_levels,
                                   distribution_mode=args.distribution_mode,
                                   paint_vel_info=args.paint_vel_info,
                                   num_processes=args.num_train_envs,
                                   num_frame_stack=args.num_frame_stack,
                                   device=device)
        val_envs = make_vec_envs(env_name=args.env_name,
                                 start_level=val_start_level,
                                 num_levels=val_num_levels,
                                 distribution_mode=args.distribution_mode,
                                 paint_vel_info=args.paint_vel_info,
                                 num_processes=args.num_val_envs,
                                 num_frame_stack=args.num_frame_stack,
                                 device=device)
    else:
        train_envs = make_vec_envs(env_name=args.env_name,
                                   start_level=args.train_start_level,
                                   num_levels=args.train_num_levels,
                                   distribution_mode=args.distribution_mode,
                                   paint_vel_info=args.paint_vel_info,
                                   num_processes=args.num_processes,
                                   num_frame_stack=args.num_frame_stack,
                                   device=device)
    # initialise environments for evaluation
    eval_envs = make_vec_envs(env_name=args.env_name,
                              start_level=0,
                              num_levels=0,
                              distribution_mode=args.distribution_mode,
                              paint_vel_info=args.paint_vel_info,
                              num_processes=args.num_processes,
                              num_frame_stack=args.num_frame_stack,
                              device=device)
    _ = eval_envs.reset()
    # initialise environments for analysing the representation
    if args.analyse_rep:
        analyse_rep_train1_envs, analyse_rep_train2_envs, analyse_rep_val_envs, analyse_rep_test_envs = make_rep_analysis_envs(
            args, device)

    print("Setting up Actor-Critic model and Training algorithm.")
    # initialise policy network
    actor_critic = ACModel(obs_shape=train_envs.observation_space.shape,
                           action_space=train_envs.action_space,
                           hidden_size=args.hidden_size,
                           use_bottleneck=args.use_bottleneck,
                           sni_type=args.sni_type).to(device)

    # initialise policy training algorithm
    if args.algo == 'ppo':
        policy = PPO(actor_critic=actor_critic,
                     ppo_epoch=args.policy_ppo_epoch,
                     num_mini_batch=args.policy_num_mini_batch,
                     clip_param=args.policy_clip_param,
                     value_loss_coef=args.policy_value_loss_coef,
                     entropy_coef=args.policy_entropy_coef,
                     max_grad_norm=args.policy_max_grad_norm,
                     lr=args.policy_lr,
                     eps=args.policy_eps,
                     vib_coef=args.vib_coef,
                     sni_coef=args.sni_coef,
                     use_dist_matching=args.use_dist_matching,
                     dist_matching_loss=args.dist_matching_loss,
                     dist_matching_coef=args.dist_matching_coef,
                     num_train_envs=args.num_train_envs,
                     num_val_envs=args.num_val_envs)
    else:
        raise NotImplementedError

    # initialise rollout storage for the policy training algorithm
    rollouts = RolloutStorage(num_steps=args.policy_num_steps,
                              num_processes=args.num_processes,
                              obs_shape=train_envs.observation_space.shape,
                              action_space=train_envs.action_space)

    # count number of frames and updates
    frames = 0
    iter_idx = 0

    # update wandb args
    wandb.config.update(args)
    # wandb.watch(actor_critic, log="all") # to log gradients of actor-critic network

    update_start_time = time.time()

    # reset environments
    if args.num_val_envs > 0:
        obs = torch.cat([train_envs.reset(),
                         val_envs.reset()])  # obs.shape = (n_envs,C,H,W)
    else:
        obs = train_envs.reset()  # obs.shape = (n_envs,C,H,W)

    # insert initial observation to rollout storage
    rollouts.obs[0].copy_(obs)
    rollouts.to(device)

    # initialise buffer for calculating mean episodic returns
    train_episode_info_buf = deque(maxlen=10)
    val_episode_info_buf = deque(maxlen=10)

    # calculate number of updates
    # number of frames ÷ number of policy steps before update ÷ number of processes
    args.num_batch = args.num_processes * args.policy_num_steps
    args.num_updates = int(args.num_frames) // args.num_batch
    print("Training beginning.")
    print("Number of updates: ", args.num_updates)
    for iter_idx in range(args.num_updates):
        print("Iter: ", iter_idx)

        # put actor-critic into train mode
        actor_critic.train()

        # rollout policy to collect num_batch of experience and place in storage
        for step in range(args.policy_num_steps):

            # sample actions from policy
            with torch.no_grad():
                value, action, action_log_prob, _ = actor_critic.act(
                    rollouts.obs[step])

            # observe rewards and next obs
            if args.num_val_envs > 0:
                obs, reward, done, infos = train_envs.step(
                    action[:args.num_train_envs, :])
                val_obs, val_reward, val_done, val_infos = val_envs.step(
                    action[args.num_train_envs:, :])
                obs = torch.cat([obs, val_obs])
                reward = torch.cat([reward, val_reward])
                done, val_done = list(done), list(val_done)
                done.extend(val_done)
                infos.extend(val_infos)
            else:
                obs, reward, done, infos = train_envs.step(action)

            # log episode info if episode finished
            for i, info in enumerate(infos):
                if i < args.num_train_envs and 'episode' in info.keys():
                    train_episode_info_buf.append(info['episode'])
                elif i >= args.num_train_envs and 'episode' in info.keys():
                    val_episode_info_buf.append(info['episode'])

            # create mask for episode ends
            masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                       for done_ in done]).to(device)

            # add experience to storage
            rollouts.insert(obs, reward, action, value, action_log_prob, masks)

            frames += args.num_processes

        # --- UPDATE ---

        # bootstrap next value prediction
        with torch.no_grad():
            next_value = actor_critic.get_value(rollouts.obs[-1]).detach()

        # compute returns for current rollouts
        rollouts.compute_returns(next_value, args.policy_gamma,
                                 args.policy_gae_lambda)

        # update actor-critic using policy gradient algo
        total_loss, value_loss, action_loss, dist_entropy, vib_kl, dist_matching_loss = policy.update(
            rollouts)

        # clean up storage after update
        rollouts.after_update()

        # --- LOGGING ---

        if iter_idx % args.log_interval == 0 or iter_idx == args.num_updates - 1:

            # --- EVALUATION ---
            eval_episode_info_buf = utl_eval.evaluate(
                eval_envs=eval_envs, actor_critic=actor_critic, device=device)

            # --- ANALYSE REPRESENTATION ---
            if args.analyse_rep:
                rep_measures = utl_rep.analyse_rep(
                    args=args,
                    train1_envs=analyse_rep_train1_envs,
                    train2_envs=analyse_rep_train2_envs,
                    val_envs=analyse_rep_val_envs,
                    test_envs=analyse_rep_test_envs,
                    actor_critic=actor_critic,
                    device=device)

            # get stats for run
            update_end_time = time.time()
            num_interval_updates = 1 if iter_idx == 0 else args.log_interval
            fps = num_interval_updates * (
                args.num_processes *
                args.policy_num_steps) / (update_end_time - update_start_time)
            update_start_time = update_end_time
            # calculates whether the value function is a good predictor of the returns (ev close to 1)
            # or if it's just worse than predicting nothing (ev <= 0)
            ev = utl_math.explained_variance(utl.sf01(rollouts.value_preds),
                                             utl.sf01(rollouts.returns))

            wandb.log(
                {
                    'misc/timesteps':
                    frames,
                    'misc/fps':
                    fps,
                    'misc/explained_variance':
                    float(ev),
                    'losses/total_loss':
                    total_loss,
                    'losses/value_loss':
                    value_loss,
                    'losses/action_loss':
                    action_loss,
                    'losses/dist_entropy':
                    dist_entropy,
                    'train/mean_episodic_return':
                    utl_math.safe_mean([
                        episode_info['r']
                        for episode_info in train_episode_info_buf
                    ]),
                    'train/mean_episodic_length':
                    utl_math.safe_mean([
                        episode_info['l']
                        for episode_info in train_episode_info_buf
                    ]),
                    'eval/mean_episodic_return':
                    utl_math.safe_mean([
                        episode_info['r']
                        for episode_info in eval_episode_info_buf
                    ]),
                    'eval/mean_episodic_length':
                    utl_math.safe_mean([
                        episode_info['l']
                        for episode_info in eval_episode_info_buf
                    ])
                },
                step=iter_idx)
            if args.use_bottleneck:
                wandb.log({'losses/vib_kl': vib_kl}, step=iter_idx)
            if args.num_val_envs > 0:
                wandb.log(
                    {
                        'losses/dist_matching_loss':
                        dist_matching_loss,
                        'val/mean_episodic_return':
                        utl_math.safe_mean([
                            episode_info['r']
                            for episode_info in val_episode_info_buf
                        ]),
                        'val/mean_episodic_length':
                        utl_math.safe_mean([
                            episode_info['l']
                            for episode_info in val_episode_info_buf
                        ])
                    },
                    step=iter_idx)
            if args.analyse_rep:
                wandb.log(
                    {
                        "analysis/" + key: val
                        for key, val in rep_measures.items()
                    },
                    step=iter_idx)

        # --- SAVE MODEL ---

        # save for every interval-th episode or for the last epoch
        if iter_idx != 0 and (iter_idx % args.save_interval == 0
                              or iter_idx == args.num_updates - 1):
            print("Saving Actor-Critic Model.")
            torch.save(actor_critic.state_dict(),
                       os.path.join(save_dir, "policy{0}.pt".format(iter_idx)))

    # close envs
    train_envs.close()
    eval_envs.close()

    # --- TEST ---

    if args.test:
        print("Testing beginning.")
        episodic_return, latents_z = utl_test.test(args=args,
                                                   actor_critic=actor_critic,
                                                   device=device)

        # save returns from train and test levels to analyse using interactive mode
        train_levels = torch.arange(
            args.train_start_level,
            args.train_start_level + args.train_num_levels)
        for i, level in enumerate(train_levels):
            wandb.log({
                'test/train_levels': level,
                'test/train_returns': episodic_return[0][i]
            })
        test_levels = torch.arange(
            args.test_start_level,
            args.test_start_level + args.test_num_levels)
        for i, level in enumerate(test_levels):
            wandb.log({
                'test/test_levels': level,
                'test/test_returns': episodic_return[1][i]
            })
        # log returns from test envs
        wandb.run.summary["train_mean_episodic_return"] = utl_math.safe_mean(
            episodic_return[0])
        wandb.run.summary["test_mean_episodic_return"] = utl_math.safe_mean(
            episodic_return[1])

        # plot latent representation
        if args.plot_pca:
            print("Plotting PCA of Latent Representation.")
            utl_rep.pca(args, latents_z)
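Examples #11 and #13 log utl_math.explained_variance as a diagnostic for the value function. A minimal sketch of the standard definition it presumably implements, ev = 1 - Var(returns - values) / Var(returns):

import numpy as np


def explained_variance(values, returns):
    """1.0 means perfect value predictions, 0.0 no better than a constant, < 0 worse."""
    var_returns = np.var(returns)
    if var_returns == 0:
        return float('nan')
    return 1.0 - np.var(returns - values) / var_returns


returns = np.array([1.0, 2.0, 3.0, 4.0])
print(explained_variance(returns, returns))                 # 1.0
print(explained_variance(np.zeros_like(returns), returns))  # 0.0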
Example #14
0
def train():

    model_dir = Path('./models') / ENV_ID
    if not model_dir.exists():
        curr_run = 'run1'
    else:
        exst_run_nums = [
            int(str(folder.name).split('run')[1])
            for folder in model_dir.iterdir()
            if str(folder.name).startswith('run')
        ]
        if len(exst_run_nums) == 0:
            curr_run = 'run1'
        else:
            curr_run = 'run%i' % (max(exst_run_nums) + 1)
    run_dir = model_dir / curr_run
    log_dir = run_dir / 'logs'
    os.makedirs(str(log_dir))
    logger = SummaryWriter(str(log_dir), max_queue=5, flush_secs=30)

    torch.manual_seed(RANDOM_SEED)
    np.random.seed(RANDOM_SEED)

    env = make_parallel_env(ENV_ID, N_ROLLOUT_THREADS, RANDOM_SEED)
    ppo = PPO.init_from_env(env,
                            gamma=GAMMA,
                            lam=LAMDA,
                            lr=LR,
                            coeff_entropy=COEFF_ENTROPY,
                            batch_size=BATCH_SIZE)

    # save_dict = torch.load(model_dir / 'run1/model.pt')  # TODO init from save?
    # save_dict = save_dict['model_params']['policy']
    # ppo.policy.load_state_dict(save_dict)

    buff = []

    t = 0
    for ep_i in range(0, N_EPISODES, N_ROLLOUT_THREADS):
        print("Episodes %i-%i of %i" %
              (ep_i + 1, ep_i + 1 + N_ROLLOUT_THREADS, N_EPISODES))
        obs = env.reset()
        nagents = obs.shape[1]
        ppo.prep_rollouts(device='cpu')

        ep_rew = 0
        act_hidden = [[
            torch.zeros(N_ROLLOUT_THREADS, 128),
            torch.zeros(N_ROLLOUT_THREADS, 128)
        ] for i in range(nagents)]
        crt_hidden = [[
            torch.zeros(N_ROLLOUT_THREADS, 128),
            torch.zeros(N_ROLLOUT_THREADS, 128)
        ] for i in range(nagents)]

        for et_i in range(EPISODE_LEN):
            """
            generate actions
            """
            torch_obs = [
                Variable(torch.Tensor(np.vstack(obs[:, i])),
                         requires_grad=False) for i in range(nagents)
            ]
            prev_act_hidden = [[h.data.cpu().numpy(),
                                c.data.cpu().numpy()] for h, c in act_hidden]
            prev_crt_hidden = [[h.data.cpu().numpy(),
                                c.data.cpu().numpy()] for h, c in crt_hidden]
            v_list, agent_actions_list, logprob_list, mean_list = ppo.step(
                torch_obs, act_hidden, crt_hidden)
            v_list = [a.data.cpu().numpy() for a in v_list]
            agent_actions_list = [
                a.data.cpu().numpy() for a in agent_actions_list
            ]
            logprob_list = [a.data.cpu().numpy() for a in logprob_list]
            clipped_action_list = [
                np.clip(a, -1, 1) for a in agent_actions_list
            ]
            actions = [[ac[i] for ac in clipped_action_list]
                       for i in range(N_ROLLOUT_THREADS)]
            """
            step env
            """
            next_obs, rewards, dones, infos = env.step(actions)
            ep_rew += np.mean(rewards)
            buff.append(
                (obs, prev_act_hidden, prev_crt_hidden, agent_actions_list,
                 rewards, dones, logprob_list, v_list))
            obs = next_obs
            t += N_ROLLOUT_THREADS

            for i, done in enumerate(dones[0]):
                if done:
                    act_hidden[i] = [
                        torch.zeros(N_ROLLOUT_THREADS, 128),
                        torch.zeros(N_ROLLOUT_THREADS, 128)
                    ]
                    crt_hidden[i] = [
                        torch.zeros(N_ROLLOUT_THREADS, 128),
                        torch.zeros(N_ROLLOUT_THREADS, 128)
                    ]
                    env.envs[0].agents[i].terminate = False
            # if dones.any():
            #     break
        print('mean reward:', ep_rew)
        """
        train
        """
        next_obs = [
            Variable(torch.Tensor(np.vstack(obs[:, i])), requires_grad=False)
            for i in range(nagents)
        ]
        v_list, _, _, _ = ppo.step(next_obs, act_hidden, crt_hidden)
        v_list = [a.data.cpu().numpy() for a in v_list]

        print('updating params...')
        if USE_CUDA:
            ppo.prep_training(device='gpu')
        else:
            ppo.prep_training(device='cpu')
        ppo.update(buff=buff, last_v=v_list, to_gpu=USE_CUDA)
        ppo.prep_rollouts(device='cpu')
        buff = []

        logger.add_scalar('mean_episode_rewards', ep_rew, ep_i)

        if ep_i % SAVE_INTERVAL < N_ROLLOUT_THREADS:
            print('saving incremental...')
            os.makedirs(str(run_dir / 'incremental'), exist_ok=True)
            ppo.save(
                str(run_dir / 'incremental' / ('model_ep%i.pt' % (ep_i + 1))))
            ppo.save(str(run_dir / 'model.pt'))

    print('saving model...')
    ppo.save(str(run_dir / 'model.pt'))
    env.close()
    logger.export_scalars_to_json(str(log_dir / 'summary.json'))
    logger.close()
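Several of these training loops build episode-end masks with a Python list comprehension and, as in Example #14, reset recurrent hidden states when an episode finishes. An equivalent vectorised sketch (the tensor shapes are assumptions, not taken from Example #14):

import numpy as np
import torch

done = np.array([False, True, False])  # one flag per environment
masks = 1.0 - torch.as_tensor(done, dtype=torch.float32).unsqueeze(1)  # shape (n_envs, 1)

hidden = torch.randn(3, 128)  # e.g. one LSTM hidden state per environment
hidden = hidden * masks       # zero the state of finished episodes
print(masks.squeeze(1).tolist())  # [1.0, 0.0, 1.0]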
Example #15
0
class MetaLearner:
    """
    Meta-Learner class with the main training loop for variBAD.
    """
    def __init__(self, args):
        self.args = args
        utl.seed(self.args.seed, self.args.deterministic_execution)

        # count number of frames and number of meta-iterations
        self.frames = 0
        self.iter_idx = 0

        # initialise tensorboard logger
        self.logger = TBLogger(self.args, self.args.exp_label)

        # initialise environments
        self.envs = make_vec_envs(
            env_name=args.env_name,
            seed=args.seed,
            num_processes=args.num_processes,
            gamma=args.policy_gamma,
            log_dir=args.agent_log_dir,
            device=device,
            allow_early_resets=False,
            episodes_per_task=self.args.max_rollouts_per_task,
            obs_rms=None,
            ret_rms=None,
        )

        # calculate what the maximum length of the trajectories is
        args.max_trajectory_len = self.envs._max_episode_steps
        args.max_trajectory_len *= self.args.max_rollouts_per_task

        # calculate number of meta updates
        self.args.num_updates = int(
            args.num_frames) // args.policy_num_steps // args.num_processes

        # get action / observation dimensions
        if isinstance(self.envs.action_space, gym.spaces.discrete.Discrete):
            self.args.action_dim = 1
        else:
            self.args.action_dim = self.envs.action_space.shape[0]
        self.args.obs_dim = self.envs.observation_space.shape[0]
        self.args.num_states = self.envs.num_states if self.args.env_name.startswith(
            'Grid') else None
        self.args.act_space = self.envs.action_space

        self.vae = VaribadVAE(self.args, self.logger, lambda: self.iter_idx)

        self.initialise_policy()

    def initialise_policy(self):

        # initialise rollout storage for the policy
        self.policy_storage = OnlineStorage(
            self.args,
            self.args.policy_num_steps,
            self.args.num_processes,
            self.args.obs_dim,
            self.args.act_space,
            hidden_size=self.args.aggregator_hidden_size,
            latent_dim=self.args.latent_dim,
            normalise_observations=self.args.norm_obs_for_policy,
            normalise_rewards=self.args.norm_rew_for_policy,
        )

        # initialise policy network
        input_dim = self.args.obs_dim * int(
            self.args.condition_policy_on_state)
        input_dim += (
            1 + int(not self.args.sample_embeddings)) * self.args.latent_dim
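        # i.e. the policy input is the state (if conditioned on it) plus either the latent
        # sample (1 * latent_dim) or the latent mean and logvar (2 * latent_dim)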

        if hasattr(self.envs.action_space, 'low'):
            action_low = self.envs.action_space.low
            action_high = self.envs.action_space.high
        else:
            action_low = action_high = None

        policy_net = Policy(
            state_dim=input_dim,
            action_space=self.args.act_space,
            init_std=self.args.policy_init_std,
            hidden_layers=self.args.policy_layers,
            activation_function=self.args.policy_activation_function,
            normalise_actions=self.args.normalise_actions,
            action_low=action_low,
            action_high=action_high,
        ).to(device)

        # initialise policy trainer
        if self.args.policy == 'a2c':
            self.policy = A2C(
                policy_net,
                self.args.policy_value_loss_coef,
                self.args.policy_entropy_coef,
                optimiser_vae=self.vae.optimiser_vae,
                lr=self.args.lr_policy,
                eps=self.args.policy_eps,
                alpha=self.args.a2c_alpha,
            )
        elif self.args.policy == 'ppo':
            self.policy = PPO(
                policy_net,
                self.args.policy_value_loss_coef,
                self.args.policy_entropy_coef,
                optimiser_vae=self.vae.optimiser_vae,
                lr=self.args.lr_policy,
                eps=self.args.policy_eps,
                ppo_epoch=self.args.ppo_num_epochs,
                num_mini_batch=self.args.ppo_num_minibatch,
                use_huber_loss=self.args.ppo_use_huberloss,
                use_clipped_value_loss=self.args.ppo_use_clipped_value_loss,
                clip_param=self.args.ppo_clip_param,
            )
        else:
            raise NotImplementedError

    def train(self):
        """
        Given some stream of environments and a logger (tensorboard),
        (meta-)trains the policy.
        """

        start_time = time.time()

        # reset environments
        (prev_obs_raw, prev_obs_normalised) = self.envs.reset()
        prev_obs_raw = prev_obs_raw.to(device)
        prev_obs_normalised = prev_obs_normalised.to(device)

        # insert initial observation / embeddings to rollout storage
        self.policy_storage.prev_obs_raw[0].copy_(prev_obs_raw)
        self.policy_storage.prev_obs_normalised[0].copy_(prev_obs_normalised)
        self.policy_storage.to(device)

        vae_is_pretrained = False
        for self.iter_idx in range(self.args.num_updates):

            # First, re-compute the hidden states given the current rollouts (since the VAE might've changed)
            # compute latent embedding (will return prior if current trajectory is empty)
            with torch.no_grad():
                latent_sample, latent_mean, latent_logvar, hidden_state = self.encode_running_trajectory(
                )

            # check if we flushed the policy storage
            assert len(self.policy_storage.latent_mean) == 0

            # add this initial hidden state to the policy storage
            self.policy_storage.hidden_states[0].copy_(hidden_state)
            self.policy_storage.latent_samples.append(latent_sample.clone())
            self.policy_storage.latent_mean.append(latent_mean.clone())
            self.policy_storage.latent_logvar.append(latent_logvar.clone())

            # rollout policies for a few steps
            for step in range(self.args.policy_num_steps):

                # sample actions from policy
                with torch.no_grad():
                    value, action, action_log_prob = utl.select_action(
                        args=self.args,
                        policy=self.policy,
                        obs=prev_obs_normalised
                        if self.args.norm_obs_for_policy else prev_obs_raw,
                        deterministic=False,
                        latent_sample=latent_sample,
                        latent_mean=latent_mean,
                        latent_logvar=latent_logvar,
                    )
                # observe reward and next obs
                (next_obs_raw, next_obs_normalised), (
                    rew_raw, rew_normalised), done, infos = utl.env_step(
                        self.envs, action)
                tasks = torch.FloatTensor([info['task']
                                           for info in infos]).to(device)
                done = torch.from_numpy(np.array(
                    done, dtype=int)).to(device).float().view((-1, 1))

                # create mask for episode ends
                masks_done = torch.FloatTensor([[0.0] if done_ else [1.0]
                                                for done_ in done]).to(device)
                # bad_mask is true if episode ended because time limit was reached
                bad_masks = torch.FloatTensor(
                    [[0.0] if 'bad_transition' in info.keys() else [1.0]
                     for info in infos]).to(device)

                # compute next embedding (for next loop and/or value prediction bootstrap)
                latent_sample, latent_mean, latent_logvar, hidden_state = utl.update_encoding(
                    encoder=self.vae.encoder,
                    next_obs=next_obs_raw,
                    action=action,
                    reward=rew_raw,
                    done=done,
                    hidden_state=hidden_state)

                # before resetting, update the embedding and add to vae buffer
                # (last state might include useful task info)
                if not (self.args.disable_decoder
                        and self.args.disable_stochasticity_in_latent):
                    self.vae.rollout_storage.insert(prev_obs_raw.clone(),
                                                    action.detach().clone(),
                                                    next_obs_raw.clone(),
                                                    rew_raw.clone(),
                                                    done.clone(),
                                                    tasks.clone())

                # add the obs before reset to the policy storage
                # (only used to recompute embeddings if rlloss is backpropagated through encoder)
                self.policy_storage.next_obs_raw[step] = next_obs_raw.clone()
                self.policy_storage.next_obs_normalised[
                    step] = next_obs_normalised.clone()

                # reset environments that are done
                done_indices = np.argwhere(
                    done.cpu().detach().flatten()).flatten()
                if len(done_indices) == self.args.num_processes:
                    [next_obs_raw, next_obs_normalised] = self.envs.reset()
                    if not self.args.sample_embeddings:
                        latent_sample = latent_sample
                else:
                    for i in done_indices:
                        [next_obs_raw[i],
                         next_obs_normalised[i]] = self.envs.reset(index=i)
                        if not self.args.sample_embeddings:
                            latent_sample[i] = latent_sample[i]

                # add experience to policy buffer
                self.policy_storage.insert(
                    obs_raw=next_obs_raw,
                    obs_normalised=next_obs_normalised,
                    actions=action,
                    action_log_probs=action_log_prob,
                    rewards_raw=rew_raw,
                    rewards_normalised=rew_normalised,
                    value_preds=value,
                    masks=masks_done,
                    bad_masks=bad_masks,
                    done=done,
                    hidden_states=hidden_state.squeeze(0).detach(),
                    latent_sample=latent_sample.detach(),
                    latent_mean=latent_mean.detach(),
                    latent_logvar=latent_logvar.detach(),
                )

                prev_obs_normalised = next_obs_normalised
                prev_obs_raw = next_obs_raw

                self.frames += self.args.num_processes

            # --- UPDATE ---

            if self.args.precollect_len <= self.frames:
                # check if we are pre-training the VAE
                if self.args.pretrain_len > 0 and not vae_is_pretrained:
                    for _ in range(self.args.pretrain_len):
                        self.vae.compute_vae_loss(update=True)
                    vae_is_pretrained = True

                # otherwise do the normal update (policy + vae)
                else:

                    train_stats = self.update(
                        obs=prev_obs_normalised
                        if self.args.norm_obs_for_policy else prev_obs_raw,
                        latent_sample=latent_sample,
                        latent_mean=latent_mean,
                        latent_logvar=latent_logvar)

                    # log
                    run_stats = [action, action_log_prob, value]
                    if train_stats is not None:
                        self.log(run_stats, train_stats, start_time)

            # clean up after update
            self.policy_storage.after_update()

    def encode_running_trajectory(self):
        """
        (Re-)Encodes (for each process) the entire current trajectory.
        Returns sample/mean/logvar and hidden state (if applicable) for the current timestep.
        :return:
        """

        # for each process, get the current batch (zero-padded obs/act/rew + length indicators)
        prev_obs, next_obs, act, rew, lens = self.vae.rollout_storage.get_running_batch(
        )

        # get embedding - will return (1+sequence_len) * batch * input_size -- includes the prior!
        all_latent_samples, all_latent_means, all_latent_logvars, all_hidden_states = self.vae.encoder(
            actions=act,
            states=next_obs,
            rewards=rew,
            hidden_state=None,
            return_prior=True)

        # get the embedding / hidden state of the current time step (need to do this since we zero-padded)
        latent_sample = (torch.stack([
            all_latent_samples[lens[i]][i] for i in range(len(lens))
        ])).detach().to(device)
        latent_mean = (torch.stack([
            all_latent_means[lens[i]][i] for i in range(len(lens))
        ])).detach().to(device)
        latent_logvar = (torch.stack([
            all_latent_logvars[lens[i]][i] for i in range(len(lens))
        ])).detach().to(device)
        hidden_state = (torch.stack([
            all_hidden_states[lens[i]][i] for i in range(len(lens))
        ])).detach().to(device)

        return latent_sample, latent_mean, latent_logvar, hidden_state

    def get_value(self, obs, latent_sample, latent_mean, latent_logvar):
        obs = utl.get_augmented_obs(self.args, obs, latent_sample, latent_mean,
                                    latent_logvar)
        return self.policy.actor_critic.get_value(obs).detach()

    def update(self, obs, latent_sample, latent_mean, latent_logvar):
        """
        Meta-update.
        Here the policy is updated for good average performance across tasks.
        :return:
        """
        # update policy (if we are not pre-training, have enough data in the vae buffer, and are not at iteration 0)
        if self.iter_idx >= self.args.pretrain_len and self.iter_idx > 0:

            # bootstrap next value prediction
            with torch.no_grad():
                next_value = self.get_value(obs=obs,
                                            latent_sample=latent_sample,
                                            latent_mean=latent_mean,
                                            latent_logvar=latent_logvar)

            # compute returns for current rollouts
            self.policy_storage.compute_returns(
                next_value,
                self.args.policy_use_gae,
                self.args.policy_gamma,
                self.args.policy_tau,
                use_proper_time_limits=self.args.use_proper_time_limits)

            # update agent (this will also call the VAE update!)
            policy_train_stats = self.policy.update(
                args=self.args,
                policy_storage=self.policy_storage,
                encoder=self.vae.encoder,
                rlloss_through_encoder=self.args.rlloss_through_encoder,
                compute_vae_loss=self.vae.compute_vae_loss)
        else:
            policy_train_stats = 0, 0, 0, 0

            # pre-train the VAE
            if self.iter_idx < self.args.pretrain_len:
                self.vae.compute_vae_loss(update=True)

        return policy_train_stats, None

    def log(self, run_stats, train_stats, start_time):
        train_stats, meta_train_stats = train_stats

        # --- visualise behaviour of policy ---

        if self.iter_idx % self.args.vis_interval == 0:
            obs_rms = self.envs.venv.obs_rms if self.args.norm_obs_for_policy else None
            ret_rms = self.envs.venv.ret_rms if self.args.norm_rew_for_policy else None

            utl_eval.visualise_behaviour(
                args=self.args,
                policy=self.policy,
                image_folder=self.logger.full_output_folder,
                iter_idx=self.iter_idx,
                obs_rms=obs_rms,
                ret_rms=ret_rms,
                encoder=self.vae.encoder,
                reward_decoder=self.vae.reward_decoder,
                state_decoder=self.vae.state_decoder,
                task_decoder=self.vae.task_decoder,
                compute_rew_reconstruction_loss=self.vae.
                compute_rew_reconstruction_loss,
                compute_state_reconstruction_loss=self.vae.
                compute_state_reconstruction_loss,
                compute_task_reconstruction_loss=self.vae.
                compute_task_reconstruction_loss,
                compute_kl_loss=self.vae.compute_kl_loss,
            )

        # --- evaluate policy ----

        if self.iter_idx % self.args.eval_interval == 0:

            obs_rms = self.envs.venv.obs_rms if self.args.norm_obs_for_policy else None
            ret_rms = self.envs.venv.ret_rms if self.args.norm_rew_for_policy else None

            returns_per_episode = utl_eval.evaluate(args=self.args,
                                                    policy=self.policy,
                                                    obs_rms=obs_rms,
                                                    ret_rms=ret_rms,
                                                    encoder=self.vae.encoder,
                                                    iter_idx=self.iter_idx)

            # log the return avg/std across tasks (=processes)
            returns_avg = returns_per_episode.mean(dim=0)
            returns_std = returns_per_episode.std(dim=0)
            for k in range(len(returns_avg)):
                self.logger.add('return_avg_per_iter/episode_{}'.format(k + 1),
                                returns_avg[k], self.iter_idx)
                self.logger.add(
                    'return_avg_per_frame/episode_{}'.format(k + 1),
                    returns_avg[k], self.frames)
                self.logger.add('return_std_per_iter/episode_{}'.format(k + 1),
                                returns_std[k], self.iter_idx)
                self.logger.add(
                    'return_std_per_frame/episode_{}'.format(k + 1),
                    returns_std[k], self.frames)

            print(
                "Updates {}, num timesteps {}, FPS {}, {} \n Mean return (train): {:.5f} \n"
                .format(self.iter_idx, self.frames,
                        int(self.frames / (time.time() - start_time)),
                        self.vae.rollout_storage.prev_obs.shape,
                        returns_avg[-1].item()))

        # --- save models ---

        if self.iter_idx % self.args.save_interval == 0:
            save_path = os.path.join(self.logger.full_output_folder, 'models')
            if not os.path.exists(save_path):
                os.mkdir(save_path)
            torch.save(
                self.policy.actor_critic,
                os.path.join(save_path, "policy{0}.pt".format(self.iter_idx)))
            torch.save(
                self.vae.encoder,
                os.path.join(save_path, "encoder{0}.pt".format(self.iter_idx)))
            if self.vae.state_decoder is not None:
                torch.save(
                    self.vae.state_decoder,
                    os.path.join(save_path,
                                 "state_decoder{0}.pt".format(self.iter_idx)))
            if self.vae.reward_decoder is not None:
                torch.save(
                    self.vae.reward_decoder,
                    os.path.join(save_path,
                                 "reward_decoder{0}.pt".format(self.iter_idx)))
            if self.vae.task_decoder is not None:
                torch.save(
                    self.vae.task_decoder,
                    os.path.join(save_path,
                                 "task_decoder{0}.pt".format(self.iter_idx)))

            # save normalisation params of envs
            if self.args.norm_rew_for_policy:
                # save rolling mean and std
                rew_rms = self.envs.venv.ret_rms
                utl.save_obj(rew_rms, save_path,
                             "env_rew_rms{0}.pkl".format(self.iter_idx))
            if self.args.norm_obs_for_policy:
                obs_rms = self.envs.venv.obs_rms
                utl.save_obj(obs_rms, save_path,
                             "env_obs_rms{0}.pkl".format(self.iter_idx))

        # --- log some other things ---

        if self.iter_idx % self.args.log_interval == 0:

            self.logger.add('policy_losses/value_loss', train_stats[0],
                            self.iter_idx)
            self.logger.add('policy_losses/action_loss', train_stats[1],
                            self.iter_idx)
            self.logger.add('policy_losses/dist_entropy', train_stats[2],
                            self.iter_idx)
            self.logger.add('policy_losses/sum', train_stats[3], self.iter_idx)

            self.logger.add('policy/action', run_stats[0][0].float().mean(),
                            self.iter_idx)
            if hasattr(self.policy.actor_critic, 'logstd'):
                self.logger.add('policy/action_logstd',
                                self.policy.actor_critic.dist.logstd.mean(),
                                self.iter_idx)
            self.logger.add('policy/action_logprob', run_stats[1].mean(),
                            self.iter_idx)
            self.logger.add('policy/value', run_stats[2].mean(), self.iter_idx)

            self.logger.add('encoder/latent_mean',
                            torch.cat(self.policy_storage.latent_mean).mean(),
                            self.iter_idx)
            self.logger.add(
                'encoder/latent_logvar',
                torch.cat(self.policy_storage.latent_logvar).mean(),
                self.iter_idx)

            # log the average weights and gradients of all models (where applicable)
            for model, name in [[self.policy.actor_critic, 'policy'],
                                [self.vae.encoder, 'encoder'],
                                [self.vae.reward_decoder, 'reward_decoder'],
                                [self.vae.state_decoder, 'state_transition_decoder'],
                                [self.vae.task_decoder, 'task_decoder']]:
                if model is not None:
                    param_list = list(model.parameters())
                    param_mean = np.mean([
                        param_list[i].data.cpu().numpy().mean()
                        for i in range(len(param_list))
                    ])
                    self.logger.add('weights/{}'.format(name), param_mean,
                                    self.iter_idx)
                    if name == 'policy':
                        self.logger.add('weights/policy_std',
                                        param_list[0].data.mean(),
                                        self.iter_idx)
                    if param_list[0].grad is not None:
                        param_grad_mean = np.mean([
                            param_list[i].grad.cpu().numpy().mean()
                            for i in range(len(param_list))
                        ])
                        self.logger.add('gradients/{}'.format(name),
                                        param_grad_mean, self.iter_idx)

    def load_and_render(self, load_iter):
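        # load saved policy/encoder checkpoints and roll them out in a single environment with rendering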
        #save_path = os.path.join('/ext/varibad_github/v2/varibad/logs/logs_HalfCheetahJoint-v0/varibad_73__15:05_17:14:07', 'models')
        #save_path = os.path.join('/ext/varibad_github/v2/varibad/logs/hfield', 'models')
        save_path = os.path.join(
            '/ext/varibad_github/v2/varibad/logs/logs_HalfCheetahBlocks-v0/varibad_73__15:05_20:20:25',
            'models')
        self.policy.actor_critic = torch.load(
            os.path.join(save_path, "policy{0}.pt".format(load_iter)))
        self.vae.encoder = torch.load(
            os.path.join(save_path, "encoder{0}.pt".format(load_iter)))

        args = self.args
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        num_processes = 1
        num_episodes = 100
        num_steps = 1999

        #import pdb; pdb.set_trace()
        # initialise environments
        envs = make_vec_envs(
            env_name=args.env_name,
            seed=args.seed,
            num_processes=num_processes,  # 1
            gamma=args.policy_gamma,
            log_dir=args.agent_log_dir,
            device=device,
            allow_early_resets=False,
            episodes_per_task=self.args.max_rollouts_per_task,
            obs_rms=None,
            ret_rms=None,
        )

        # reset latent state to prior
        latent_sample, latent_mean, latent_logvar, hidden_state = self.vae.encoder.prior(
            num_processes)

        for episode_idx in range(num_episodes):
            (prev_obs_raw, prev_obs_normalised) = envs.reset()
            prev_obs_raw = prev_obs_raw.to(device)
            prev_obs_normalised = prev_obs_normalised.to(device)
            for step_idx in range(num_steps):

                with torch.no_grad():
                    _, action, _ = utl.select_action(
                        args=self.args,
                        policy=self.policy,
                        obs=prev_obs_normalised
                        if self.args.norm_obs_for_policy else prev_obs_raw,
                        latent_sample=latent_sample,
                        latent_mean=latent_mean,
                        latent_logvar=latent_logvar,
                        deterministic=True)

                # observe reward and next obs
                (next_obs_raw, next_obs_normalised), (rew_raw, rew_normalised), \
                    done, infos = utl.env_step(envs, action)
                # render (unwrap the vectorised/normalising wrappers to reach the raw gym env)
                envs.venv.venv.envs[0].env.env.env.env.render()

                # update the hidden state
                latent_sample, latent_mean, latent_logvar, hidden_state = utl.update_encoding(
                    encoder=self.vae.encoder,
                    next_obs=next_obs_raw,
                    action=action,
                    reward=rew_raw,
                    done=None,
                    hidden_state=hidden_state)

                prev_obs_normalised = next_obs_normalised
                prev_obs_raw = next_obs_raw

                if done[0]:
                    break


def main():
    args = get_args()

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    if args.cuda and torch.cuda.is_available() and args.cuda_deterministic:
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

    log_dir = os.path.expanduser(args.log_dir)
    eval_log_dir = log_dir + "_eval"
    utils.cleanup_log_dir(log_dir)
    utils.cleanup_log_dir(eval_log_dir)

    torch.set_num_threads(1)
    device = torch.device("cuda:0" if args.cuda else "cpu")

    envs = make_vec_envs(args.env_name,
                         args.seed,
                         args.num_processes,
                         args.gamma,
                         args.log_dir,
                         device,
                         False,
                         no_obs_norm=args.no_obs_norm)

    if args.multi_action_head:
        head_infos = envs.get_attr("head_infos")[0]
        autoregressive_maps = envs.get_attr("autoregressive_maps")[0]
        action_type_masks = torch.tensor(envs.get_attr("action_type_masks")[0],
                                         dtype=torch.float32,
                                         device=device)
        action_heads = MultiActionHeads(head_infos,
                                        autoregressive_maps,
                                        action_type_masks,
                                        input_dim=args.hidden_size)
        actor_critic = MultiHeadPolicy(envs.observation_space.shape,
                                       action_heads,
                                       use_action_masks=args.use_action_masks,
                                       base_kwargs={
                                           'recurrent': args.recurrent_policy,
                                           'recurrent_type':
                                           args.recurrent_type,
                                           'hidden_size': args.hidden_size
                                       })
    else:
        actor_critic = Policy(envs.observation_space.shape,
                              envs.action_space,
                              use_action_masks=args.use_action_masks,
                              base_kwargs={
                                  'recurrent': args.recurrent_policy,
                                  'recurrent_type': args.recurrent_type,
                                  'hidden_size': args.hidden_size
                              })
    actor_critic.to(device)

    agent = PPO(actor_critic,
                args.clip_param,
                args.ppo_epoch,
                args.num_mini_batch,
                args.value_loss_coef,
                args.entropy_coef,
                lr=args.lr,
                eps=args.eps,
                max_grad_norm=args.max_grad_norm,
                recompute_returns=args.recompute_returns,
                use_gae=args.use_gae,
                gamma=args.gamma,
                gae_lambda=args.gae_lambda)

    if args.multi_action_head:
        action_head_info = envs.get_attr("head_infos")[0]
    else:
        action_head_info = None
    rollouts = RolloutStorage(
        args.num_steps,
        args.num_processes,
        envs.observation_space.shape,
        action_head_info=action_head_info,
        action_space=envs.action_space,
        recurrent_hidden_state_size=actor_critic.recurrent_hidden_state_size,
        multi_action_head=args.multi_action_head)

    obs = envs.reset()
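    # fetch the initial action masks from every env process (regrouped per action head below when using multiple heads)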
    if actor_critic.use_action_masks:
        action_masks = envs.env_method(
            "get_available_actions"
        )  # regrouped via zip below so it becomes [head_1(all_envs), head_2(all_envs), ...]
        if args.multi_action_head:
            action_masks = list(zip(*action_masks))
            for i in range(len(rollouts.actions)):
                rollouts.action_masks[i][0].copy_(torch.tensor(
                    action_masks[i]))
        else:
            rollouts.action_masks[0].copy_(
                torch.tensor(action_masks, dtype=torch.float32, device=device))
    rollouts.obs[0].copy_(obs)
    rollouts.to(device)

    episode_rewards = deque(maxlen=10)

    start = time.time()
    num_updates = int(
        args.num_env_steps) // args.num_steps // args.num_processes

    for j in range(num_updates):

        if args.use_linear_lr_decay:
            utils.update_linear_schedule(agent.optimizer, j, num_updates,
                                         args.lr)

        for step in range(args.num_steps):
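            # collect one transition from every parallel environment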
            with torch.no_grad():
                if actor_critic.is_recurrent and actor_critic.base.recurrent_type == "LSTM":
                    recurrent_hidden_state_in = (
                        rollouts.recurrent_hidden_states[step],
                        rollouts.recurrent_cell_states[step])
                else:
                    recurrent_hidden_state_in = rollouts.recurrent_hidden_states[
                        step]
                if args.multi_action_head:
                    action_masks = [
                        rollouts.action_masks[i][step]
                        for i in range(len(rollouts.actions))
                    ]
                else:
                    action_masks = rollouts.action_masks[step]
                value, action, action_log_prob, recurrent_hidden_states = actor_critic.act(
                    rollouts.obs[step],
                    recurrent_hidden_state_in,
                    rollouts.masks[step],
                    action_masks=action_masks)

            obs, reward, done, infos = envs.step(action)

            action_masks_info = []
            for info in infos:
                if 'episode' in info.keys():
                    episode_rewards.append(info['episode']['r'])
                if actor_critic.use_action_masks:
                    action_masks_info.append(info["available_actions"])

            if actor_critic.use_action_masks:
                if args.multi_action_head:
                    action_masks = list(zip(*action_masks_info))
                    for i in range(len(action_masks)):
                        action_masks[i] = torch.tensor(action_masks[i],
                                                       dtype=torch.float32,
                                                       device=device)
                else:
                    action_masks = torch.tensor(action_masks_info,
                                                dtype=torch.float32,
                                                device=device)
            else:
                action_masks = None

            masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                       for done_ in done])
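            # (masks are 0.0 where an episode ended, so returns are not bootstrapped across episode boundaries)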

            rollouts.insert(obs,
                            recurrent_hidden_states,
                            action,
                            action_log_prob,
                            value,
                            reward,
                            masks,
                            action_masks=action_masks)

        value_loss, action_loss, dist_entropy = agent.update(rollouts)
        rollouts.after_update()

        # save for every interval-th episode or for the last epoch
        if (j % args.save_interval == 0
                or j == num_updates - 1) and args.save_dir != "":
            save_path = os.path.join(args.save_dir, args.algo)
            try:
                os.makedirs(save_path)
            except OSError:
                pass

            torch.save([
                actor_critic,
                getattr(utils.get_vec_normalize(envs), 'obs_rms', None)
            ], os.path.join(save_path, args.env_name + args.extra_id + ".pt"))

        if j % args.log_interval == 0 and len(episode_rewards) > 1:
            total_num_steps = (j + 1) * args.num_processes * args.num_steps
            end = time.time()
            print(
                "Updates {}, num timesteps {}, FPS {} \n Last {} training episodes: mean/median reward {:.1f}/{:.1f}, min/max reward {:.1f}/{:.1f}\n"
                .format(j, total_num_steps,
                        int(total_num_steps / (end - start)),
                        len(episode_rewards), np.mean(episode_rewards),
                        np.median(episode_rewards), np.min(episode_rewards),
                        np.max(episode_rewards), dist_entropy, value_loss,
                        action_loss))
            # x = tuple(actor_critic.dist.logstd._bias.squeeze().detach().cpu().numpy())
            # print(("action std's: ["+', '.join(['%.2f']*len(x))+"]") % tuple([np.exp(a) for a in x]))

        if (args.eval_interval is not None and len(episode_rewards) > 1
                and j % args.eval_interval == 0):
            if not args.no_obs_norm:
                obs_rms = utils.get_vec_normalize(envs).obs_rms
            else:
                obs_rms = None
Example #17
0
    def initialise_policy(self):

        # initialise rollout storage for the policy
        self.policy_storage = OnlineStorage(
            self.args,
            self.args.policy_num_steps,
            self.args.num_processes,
            self.args.obs_dim,
            self.args.act_space,
            hidden_size=self.args.aggregator_hidden_size,
            latent_dim=self.args.latent_dim,
            normalise_observations=self.args.norm_obs_for_policy,
            normalise_rewards=self.args.norm_rew_for_policy,
        )

        # initialise policy network
        input_dim = self.args.obs_dim * int(
            self.args.condition_policy_on_state)
        input_dim += (
            1 + int(not self.args.sample_embeddings)) * self.args.latent_dim
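        # (the latent is appended once if embeddings are sampled, otherwise twice, presumably for the mean and logvar)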

        if hasattr(self.envs.action_space, 'low'):
            action_low = self.envs.action_space.low
            action_high = self.envs.action_space.high
        else:
            action_low = action_high = None

        policy_net = Policy(
            state_dim=input_dim,
            action_space=self.args.act_space,
            init_std=self.args.policy_init_std,
            hidden_layers=self.args.policy_layers,
            activation_function=self.args.policy_activation_function,
            normalise_actions=self.args.normalise_actions,
            action_low=action_low,
            action_high=action_high,
        ).to(device)

        # initialise policy trainer
        if self.args.policy == 'a2c':
            self.policy = A2C(
                policy_net,
                self.args.policy_value_loss_coef,
                self.args.policy_entropy_coef,
                optimiser_vae=self.vae.optimiser_vae,
                lr=self.args.lr_policy,
                eps=self.args.policy_eps,
                alpha=self.args.a2c_alpha,
            )
        elif self.args.policy == 'ppo':
            self.policy = PPO(
                policy_net,
                self.args.policy_value_loss_coef,
                self.args.policy_entropy_coef,
                optimiser_vae=self.vae.optimiser_vae,
                lr=self.args.lr_policy,
                eps=self.args.policy_eps,
                ppo_epoch=self.args.ppo_num_epochs,
                num_mini_batch=self.args.ppo_num_minibatch,
                use_huber_loss=self.args.ppo_use_huberloss,
                use_clipped_value_loss=self.args.ppo_use_clipped_value_loss,
                clip_param=self.args.ppo_clip_param,
            )
        else:
            raise NotImplementedError
Example #18
0
def main(_seed, _config, _run):
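    # _seed/_config/_run are injected by the experiment framework (presumably Sacred); init() builds the args namespace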
    args = init(_seed, _config, _run, post_config=post_config)

    env_name = args.env_name

    dummy_env = make_env(env_name, render=False)

    cleanup_log_dir(args.log_dir)

    try:
        os.makedirs(args.save_dir)
    except OSError:
        pass

    torch.set_num_threads(1)

    envs = make_vec_envs(env_name, args.seed, args.num_processes, args.log_dir)

    obs_shape = envs.observation_space.shape
    obs_shape = (obs_shape[0], *obs_shape[1:])

    if args.load_saved_controller:
        best_model = "{}_best.pt".format(env_name)
        model_path = os.path.join(current_dir, "models", best_model)
        print("Loading model {}".format(best_model))
        actor_critic = torch.load(model_path)
    else:
        if args.mirror_method == MirrorMethods.net2:
            controller = SymmetricNetV2(
                *dummy_env.unwrapped.mirror_sizes,
                num_layers=6,
                hidden_size=256,
                tanh_finish=True
            )
        else:
            controller = SoftsignActor(dummy_env)
            if args.mirror_method == MirrorMethods.net:
                controller = SymmetricNet(controller, *dummy_env.unwrapped.sym_act_inds)
        actor_critic = Policy(controller)
        if args.sym_value_net:
            actor_critic.critic = SymmetricVNet(
                actor_critic.critic, controller.state_dim
            )

    mirror_function = None
    if (
        args.mirror_method == MirrorMethods.traj
        or args.mirror_method == MirrorMethods.loss
    ):
        indices = dummy_env.unwrapped.get_mirror_indices()
        mirror_function = get_mirror_function(indices)

    if args.cuda:
        actor_critic.cuda()

    agent = PPO(actor_critic, mirror_function=mirror_function, **args.ppo_params)

    rollouts = RolloutStorage(
        args.num_steps,
        args.num_processes,
        obs_shape,
        envs.action_space.shape[0],
        actor_critic.state_size,
    )
    current_obs = torch.zeros(args.num_processes, *obs_shape)

    def update_current_obs(obs):
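        # write the newest frame into the last observation channels of current_obs (supports frame stacking)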
        shape_dim0 = envs.observation_space.shape[0]
        obs = torch.from_numpy(obs).float()
        current_obs[:, -shape_dim0:] = obs

    obs = envs.reset()
    update_current_obs(obs)

    rollouts.observations[0].copy_(current_obs)

    if args.cuda:
        current_obs = current_obs.cuda()
        rollouts.cuda()

    episode_rewards = deque(maxlen=args.num_processes)
    num_updates = int(args.num_frames) // args.num_steps // args.num_processes

    start = time.time()
    next_checkpoint = args.save_every
    max_ep_reward = float("-inf")

    logger = ConsoleCSVLogger(
        log_dir=args.experiment_dir, console_log_interval=args.log_interval
    )

    for j in range(num_updates):

        if args.lr_decay_type == "linear":
            scheduled_lr = linear_decay(j, num_updates, args.lr, final_value=0)
        elif args.lr_decay_type == "exponential":
            scheduled_lr = exponential_decay(j, 0.99, args.lr, final_value=3e-5)
        else:
            scheduled_lr = args.lr

        set_optimizer_lr(agent.optimizer, scheduled_lr)

        for step in range(args.num_steps):
            # Sample actions
            with torch.no_grad():
                value, action, action_log_prob, states = actor_critic.act(
                    rollouts.observations[step],
                    rollouts.states[step],
                    rollouts.masks[step],
                )
            cpu_actions = action.squeeze(1).cpu().numpy()

            obs, reward, done, infos = envs.step(cpu_actions)
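            # stack the per-process rewards into a (num_processes, 1) float tensor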
            reward = torch.from_numpy(np.expand_dims(np.stack(reward), 1)).float()

            bad_masks = np.ones((args.num_processes, 1))
            for p_index, info in enumerate(infos):
                keys = info.keys()
                # This information is added by algorithms.utils.TimeLimitMask
                if "bad_transition" in keys:
                    bad_masks[p_index] = 0.0
                # This information is added by baselines.bench.Monitor
                if "episode" in keys:
                    episode_rewards.append(info["episode"]["r"])

            masks = torch.FloatTensor([[0.0] if done_ else [1.0] for done_ in done])
            bad_masks = torch.from_numpy(bad_masks)

            update_current_obs(obs)
            rollouts.insert(
                current_obs,
                states,
                action,
                action_log_prob,
                value,
                reward,
                masks,
                bad_masks,
            )

        with torch.no_grad():
            next_value = actor_critic.get_value(
                rollouts.observations[-1], rollouts.states[-1], rollouts.masks[-1]
            ).detach()

        rollouts.compute_returns(next_value, args.use_gae, args.gamma, args.gae_lambda)

        value_loss, action_loss, dist_entropy = agent.update(rollouts)

        rollouts.after_update()

        frame_count = (j + 1) * args.num_steps * args.num_processes
        if (
            frame_count >= next_checkpoint or j == num_updates - 1
        ) and args.save_dir != "":
            model_name = "{}_{:d}.pt".format(env_name, int(next_checkpoint))
            next_checkpoint += args.save_every
        else:
            model_name = "{}_latest.pt".format(env_name)

        # A really ugly way to save a model to CPU
        save_model = actor_critic
        if args.cuda:
            save_model = copy.deepcopy(actor_critic).cpu()
        # optionally mirror the checkpoint to Google Drive (Colab setup)
        drive = 1
        if drive:
            torch.save(save_model, os.path.join("/content/gdrive/My Drive/darwin", model_name))
        torch.save(save_model, os.path.join(args.save_dir, model_name))

        if len(episode_rewards) > 1 and np.mean(episode_rewards) > max_ep_reward:
            model_name = "{}_best.pt".format(env_name)
            max_ep_reward = np.mean(episode_rewards)
            drive = 1
            if drive:
                # also mirror the best checkpoint to Google Drive
                torch.save(save_model, os.path.join("/content/gdrive/My Drive/darwin", model_name))
            torch.save(save_model, os.path.join(args.save_dir, model_name))

        if len(episode_rewards) > 1:
            end = time.time()
            total_num_steps = (j + 1) * args.num_processes * args.num_steps
            logger.log_epoch(
                {
                    "iter": j + 1,
                    "total_num_steps": total_num_steps,
                    "fps": int(total_num_steps / (end - start)),
                    "entropy": dist_entropy,
                    "value_loss": value_loss,
                    "action_loss": action_loss,
                    "stats": {"rew": episode_rewards},
                }
            )
Example #19
0
def main():

    # --- ARGUMENTS ---
    parser = argparse.ArgumentParser()
    parser.add_argument('--env_name',
                        type=str,
                        default='coinrun',
                        help='name of the environment to train on.')
    parser.add_argument('--model',
                        type=str,
                        default='ppo',
                        help='the model to use for training.')
    args, rest_args = parser.parse_known_args()
    env_name = args.env_name
    model = args.model

    # get arguments
    args = args_pretrain_aup.get_args(rest_args)
    # place other args back into argparse.Namespace
    args.env_name = env_name
    args.model = model

    # Weights & Biases logger
    if args.run_name is None:
        # make run name as {env_name}_{TIME}
        now = datetime.datetime.now().strftime('_%d-%m_%H:%M:%S')
        args.run_name = args.env_name + '_' + args.algo + now
    # initialise wandb
    wandb.init(name=args.run_name,
               project=args.proj_name,
               group=args.group_name,
               config=args,
               monitor_gym=False)
    # save wandb dir path
    args.run_dir = wandb.run.dir
    wandb.config.update(args)

    # set random seed of random, torch and numpy
    utl.set_global_seed(args.seed, args.deterministic_execution)

    # --- OBTAIN DATA FOR TRAINING R_aux ---
    print("Gathering data for R_aux Model.")

    # gather observations for pretraining the auxiliary reward function (CB-VAE)
    envs = make_vec_envs(env_name=args.env_name,
                         start_level=0,
                         num_levels=0,
                         distribution_mode=args.distribution_mode,
                         paint_vel_info=args.paint_vel_info,
                         num_processes=args.num_processes,
                         num_frame_stack=args.num_frame_stack,
                         device=device)

    # batch size = number of cpu processes × number of policy steps per update;
    # number of updates = total number of frames ÷ batch size
    num_batch = args.num_processes * args.policy_num_steps
    num_updates = int(args.num_frames_r_aux) // num_batch

    # create tensor to store env observations
    obs_data = torch.zeros(num_updates * args.policy_num_steps + 1,
                           args.num_processes, *envs.observation_space.shape)
    # reset environments
    obs = envs.reset()  # obs.shape = (n_env,C,H,W)
    obs_data[0].copy_(obs)
    obs = obs.to(device)

    for iter_idx in range(num_updates):
        # rollout policy to collect num_batch of experience and store in storage
        for step in range(args.policy_num_steps):
            # sample actions from random agent
            action = torch.randint(0, envs.action_space.n,
                                   (args.num_processes, 1))
            # observe rewards and next obs
            obs, reward, done, infos = envs.step(action)
            # store obs
            obs_data[1 + iter_idx * args.policy_num_steps + step].copy_(obs)
    # close envs
    envs.close()

    # --- TRAIN R_aux (CB-VAE) ---
    # define CB-VAE where the encoder will be used as the auxiliary reward function R_aux
    print("Training R_aux Model.")

    # create dataloader for observations gathered
    obs_data = obs_data.reshape(-1, *envs.observation_space.shape)
    sampler = BatchSampler(SubsetRandomSampler(range(obs_data.size(0))),
                           args.cb_vae_batch_size,
                           drop_last=False)

    # initialise CB-VAE
    cb_vae = CBVAE(obs_shape=envs.observation_space.shape,
                   latent_dim=args.cb_vae_latent_dim).to(device)
    # optimiser
    optimiser = torch.optim.Adam(cb_vae.parameters(),
                                 lr=args.cb_vae_learning_rate)
    # put CB-VAE into train mode
    cb_vae.train()

    measures = defaultdict(list)
    for epoch in range(args.cb_vae_epochs):
        print("Epoch: ", epoch)
        start_time = time.time()
        batch_loss = 0
        for indices in sampler:
            obs = obs_data[indices].to(device)
            # zero accumulated gradients
            cb_vae.zero_grad()
            # forward pass through CB-VAE
            recon_batch, mu, log_var = cb_vae(obs)
            # calculate loss
            loss = cb_vae_loss(recon_batch, obs, mu, log_var)
            # backpropagation: compute gradients
            loss.backward()
            # update the CB-VAE parameters
            optimiser.step()

            # save loss per mini-batch
            batch_loss += loss.item() * obs.size(0)
        # log losses per epoch
        wandb.log({
            'cb_vae/loss': batch_loss / obs_data.size(0),
            'cb_vae/time_taken': time.time() - start_time,
            'cb_vae/epoch': epoch
        })
    indices = np.random.randint(0, obs.size(0), args.cb_vae_num_samples**2)
    measures['true_images'].append(obs[indices].detach().cpu().numpy())
    measures['recon_images'].append(
        recon_batch[indices].detach().cpu().numpy())

    # plot ground truth images
    plt.rcParams.update({'font.size': 10})
    fig, axs = plt.subplots(args.cb_vae_num_samples,
                            args.cb_vae_num_samples,
                            figsize=(20, 20))
    for i, img in enumerate(measures['true_images'][0]):
        axs[i // args.cb_vae_num_samples][i % args.cb_vae_num_samples].imshow(
            img.transpose(1, 2, 0))
        axs[i // args.cb_vae_num_samples][i % args.cb_vae_num_samples].axis('off')
    wandb.log({"Ground Truth Images": wandb.Image(plt)})
    # plot reconstructed images
    fig, axs = plt.subplots(args.cb_vae_num_samples,
                            args.cb_vae_num_samples,
                            figsize=(20, 20))
    for i, img in enumerate(measures['recon_images'][0]):
        axs[i // args.cb_vae_num_samples][i % args.cb_vae_num_samples].imshow(
            img.transpose(1, 2, 0))
        axs[i // args.cb_vae_num_samples][i % args.cb_vae_num_samples].axis('off')
    wandb.log({"Reconstructed Images": wandb.Image(plt)})

    # --- TRAIN Q_aux ---
    # train a PPO agent whose value head is replaced by an action-value head, using R_aux instead of the environment reward R
    print("Training Q_aux Model.")

    # initialise environments for training Q_aux
    envs = make_vec_envs(env_name=args.env_name,
                         start_level=0,
                         num_levels=0,
                         distribution_mode=args.distribution_mode,
                         paint_vel_info=args.paint_vel_info,
                         num_processes=args.num_processes,
                         num_frame_stack=args.num_frame_stack,
                         device=device)

    # initialise policy network
    actor_critic = QModel(obs_shape=envs.observation_space.shape,
                          action_space=envs.action_space,
                          hidden_size=args.hidden_size).to(device)

    # initialise policy trainer
    if args.algo == 'ppo':
        policy = PPO(actor_critic=actor_critic,
                     ppo_epoch=args.policy_ppo_epoch,
                     num_mini_batch=args.policy_num_mini_batch,
                     clip_param=args.policy_clip_param,
                     value_loss_coef=args.policy_value_loss_coef,
                     entropy_coef=args.policy_entropy_coef,
                     max_grad_norm=args.policy_max_grad_norm,
                     lr=args.policy_lr,
                     eps=args.policy_eps)
    else:
        raise NotImplementedError

    # initialise rollout storage for the policy
    rollouts = RolloutStorage(num_steps=args.policy_num_steps,
                              num_processes=args.num_processes,
                              obs_shape=envs.observation_space.shape,
                              action_space=envs.action_space)

    # count number of frames and updates
    frames = 0
    iter_idx = 0

    update_start_time = time.time()
    # reset environments
    obs = envs.reset()  # obs.shape = (n_envs,C,H,W)
    # insert initial observation to rollout storage
    rollouts.obs[0].copy_(obs)
    rollouts.to(device)

    # initialise buffer for calculating mean episodic returns
    episode_info_buf = deque(maxlen=10)

    # calculate number of updates:
    # batch size = number of cpu processes × number of policy steps per update;
    # number of updates = total number of frames ÷ batch size
    args.num_batch = args.num_processes * args.policy_num_steps
    args.num_updates = int(args.num_frames_q_aux) // args.num_batch
    print("Number of updates: ", args.num_updates)
    for iter_idx in range(args.num_updates):
        print("Iter: ", iter_idx)

        # put actor-critic into train mode
        actor_critic.train()

        # rollout policy to collect num_batch of experience and store in storage
        for step in range(args.policy_num_steps):

            with torch.no_grad():
                # sample actions from policy
                value, action, action_log_prob = actor_critic.act(
                    rollouts.obs[step])
                # obtain reward R_aux from encoder of CB-VAE
                r_aux, _, _ = cb_vae.encode(rollouts.obs[step])

            # observe rewards and next obs
            obs, _, done, infos = envs.step(action)

            # log episode info if episode finished
            for i, info in enumerate(infos):
                if 'episode' in info.keys():
                    episode_info_buf.append(info['episode'])
            # create mask for episode ends
            masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                       for done_ in done]).to(device)

            # add experience to policy buffer
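            # (the stored reward is R_aux from the CB-VAE encoder; the environment reward was discarded above)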
            rollouts.insert(obs, r_aux, action, value, action_log_prob, masks)

            frames += args.num_processes

        # --- UPDATE ---

        # bootstrap next value prediction
        with torch.no_grad():
            next_value = actor_critic.get_value(rollouts.obs[-1]).detach()

        # compute returns for current rollouts
        rollouts.compute_returns(next_value, args.policy_gamma,
                                 args.policy_gae_lambda)

        # update actor-critic using policy gradient algo
        total_loss, value_loss, action_loss, dist_entropy = policy.update(
            rollouts)

        # clean up after update
        rollouts.after_update()

        # --- LOGGING ---

        if iter_idx % args.log_interval == 0 or iter_idx == args.num_updates - 1:
            # get stats for run
            update_end_time = time.time()
            num_interval_updates = 1 if iter_idx == 0 else args.log_interval
            fps = num_interval_updates * (
                args.num_processes *
                args.policy_num_steps) / (update_end_time - update_start_time)
            update_start_time = update_end_time
            # Check whether the value function is a good predictor of the returns (ev close to 1)
            # or worse than just predicting nothing (ev <= 0)
            ev = utl_math.explained_variance(utl.sf01(rollouts.value_preds),
                                             utl.sf01(rollouts.returns))

            wandb.log({
                'q_aux_misc/timesteps': frames,
                'q_aux_misc/fps': fps,
                'q_aux_misc/explained_variance': float(ev),
                'q_aux_losses/total_loss': total_loss,
                'q_aux_losses/value_loss': value_loss,
                'q_aux_losses/action_loss': action_loss,
                'q_aux_losses/dist_entropy': dist_entropy,
                'q_aux_train/mean_episodic_return': utl_math.safe_mean(
                    [episode_info['r'] for episode_info in episode_info_buf]),
                'q_aux_train/mean_episodic_length': utl_math.safe_mean(
                    [episode_info['l'] for episode_info in episode_info_buf])
            })

    # close envs
    envs.close()

    # --- SAVE MODEL ---
    print("Saving Q_aux Model.")
    torch.save(actor_critic.state_dict(), args.q_aux_path)
Example #20
0
def main():

    # make the environments
    if args.num_envs == 1:
        env = [gym.make(args.env_name)]
    else:
        env = [gym.make(args.env_name) for i in range(args.num_envs)]

    env = MultiGym(env, render=args.render)

    n_states = env.observation_space.shape
    n_actions = env.action_space.n
    print('state shape:', n_states, 'actions:', n_actions)

    policy = ConvPolicy(n_actions).to(device)
    optimizer = optim.RMSprop(policy.parameters(), lr=args.lr)

    if args.algo == 'ppo':
        sys.path.append('../')
        from algorithms.ppo import PPO
        update_algo = PPO(policy=policy,
                          optimizer=optimizer,
                          num_steps=args.num_steps,
                          num_envs=args.num_envs,
                          state_size=(4, 105, 80),
                          entropy_coef=args.entropy,
                          gamma=args.gamma,
                          device=device,
                          epochs=args.ppo_epochs)
    else:
        sys.path.append('../')
        from algorithms.a2c import A2C
        update_algo = A2C(policy=policy,
                          optimizer=optimizer,
                          num_steps=args.num_steps,
                          num_envs=args.num_envs,
                          state_size=(4, 105, 80),
                          entropy_coef=args.entropy,
                          gamma=args.gamma,
                          device=device)

    end_rewards = []

    try:
        print('starting episodes')
        idx = 0
        d = False
        reward_sum = np.zeros((args.num_envs))
        restart = True
        frame = env.reset()
        mask = torch.ones(args.num_envs)
        all_start = time.time()

        for update_idx in range(args.num_updates):
            update_algo.policy.train()

            # stack the frames
            s = train_state_proc.proc_state(frame, mask=mask)

            # insert state before getting actions
            update_algo.states[0].copy_(s)

            start = time.time()
            for step in range(args.num_steps):

                with torch.no_grad():
                    # get probability dist and values
                    p, v = update_algo.policy(update_algo.states[step])
                    a = Categorical(p).sample()

                # take action get response
                frame, r, d = env.step(
                    a.cpu().numpy() if args.num_envs > 1 else [a.item()])
                s = train_state_proc.proc_state(frame, mask)

                update_algo.insert_experience(step=step,
                                              s=s,
                                              a=a,
                                              v=v,
                                              r=r,
                                              d=d)

                mask = torch.tensor(1. - d).float()
                reward_sum = (reward_sum + r)

                # if any episode finished append episode reward to list
                if d.any():
                    end_rewards.extend(reward_sum[d])

                # reset the running reward of any environment whose episode just finished
                reward_sum = reward_sum * mask.numpy()

                idx += 1

            with torch.no_grad():
                _, next_val = update_algo.policy(update_algo.states[-1])

            update_algo.update(next_val.view(1, args.num_envs).to(device),
                               next_mask=mask.to(device))

            if args.lr_decay:
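                # cosine-style decay; note that idx counts rollout steps (incremented every step), not updates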
                for params in update_algo.optimizer.param_groups:
                    params['lr'] = (
                        lr_min + 0.5 * (args.lr - lr_min) *
                        (1 + np.cos(np.pi * idx / args.num_updates)))

            # every so often, print progress to the terminal
            if (update_idx % args.log_interval
                    == 0) and (len(end_rewards) > 0):
                total_steps = (idx + 1) * args.num_envs * args.num_steps
                end = time.time()
                print(end_rewards[-10:])
                print('Updates {}\t  Time: {:.4f} \t FPS: {}'.format(
                    update_idx, end - start,
                    int(total_steps / (end - all_start))))
                print(
                    'Mean Episode Rewards: {:.2f} \t Min/Max Current Rewards: {}/{}'
                    .format(np.mean(end_rewards[-10:]), reward_sum.min(),
                            reward_sum.max()))

    except KeyboardInterrupt:
        pass

    torch.save(
        update_algo.policy.state_dict(),
        '../model_weights/{}_{}_conv.pth'.format(args.env_name, args.algo))

    import pandas as pd

    out_dict = {'avg_end_rewards': end_rewards}
    out_log = pd.DataFrame(out_dict)
    out_log.to_csv('../logs/{}_{}_rewards.csv'.format(args.env_name,
                                                      args.algo),
                   index=False)

    out_dict = {
        'actor losses': update_algo.actor_losses,
        'critic losses': update_algo.critic_losses,
        'entropy': update_algo.entropy_logs
    }
    out_log = pd.DataFrame(out_dict)
    out_log.to_csv('../logs/{}_{}_training_behavior.csv'.format(
        args.env_name, args.algo),
                   index=False)

    plt.plot(end_rewards)
    plt.show()
Example #21
0
 
def model_fn(obs):
    x = tf.layers.conv2d(obs, 32, 8, 4, activation=tf.nn.relu)
    x = tf.layers.conv2d(x, 64, 4, 2, activation=tf.nn.relu)
    x = tf.layers.conv2d(x, 64, 3, 1, activation=tf.nn.relu)
    x = tf.contrib.layers.flatten(x)
    x = tf.layers.dense(x, 512, activation=tf.nn.relu)

    logit_action_probability = tf.layers.dense(
            x, action_space,
            kernel_initializer=tf.truncated_normal_initializer(0.0, 0.01))
    state_value = tf.squeeze(tf.layers.dense(
            x, 1, kernel_initializer=tf.truncated_normal_initializer()))
    return logit_action_probability, state_value

ppo = PPO(action_space, obs_fn, model_fn, train_epoch=5, batch_size=32)

env = Raiden2(6666, num_envs=8, with_stack=False)
env_ids, states, rewards, dones = env.start()
env_states = defaultdict(partial(deque, maxlen=frame_stack))

nth_trajectory = 0
while True:
    nth_trajectory += 1
    for _ in tqdm(range(explore_steps)):
        sts = []
        for env_id, state in zip(env_ids, states):
            st = np.zeros((size, size, frame_stack), dtype=np.float32)
            im = cv2.resize(rgb2gray(state), (size, size))
            env_states[env_id].append(im)
            while len(env_states[env_id]) < frame_stack:
n_actions = env.action_space.n
print('states:', n_states, 'actions:', n_actions)

policy = GRUPolicy(n_states[0], n_actions, args.hid_size, args.num_steps,
                   args.num_envs).to(device)
optimizer = optim.RMSprop(policy.parameters(), lr=args.lr, eps=1e-5)

if args.algo == 'ppo':
    sys.path.append('../')
    from algorithms.ppo import PPO
    update_algo = PPO(policy=policy,
                      optimizer=optimizer,
                      num_steps=args.num_steps,
                      num_envs=args.num_envs,
                      state_size=n_states,
                      entropy_coef=args.entropy,
                      gamma=args.gamma,
                      device=device,
                      recurrent=True,
                      rnn_size=args.hid_size,
                      epochs=args.ppo_epochs,
                      batch_size=args.batch_size)
else:
    sys.path.append('../')
    from algorithms.a2c import A2C
    update_algo = A2C(policy=policy,
                      optimizer=optimizer,
                      num_steps=args.num_steps,
                      num_envs=args.num_envs,
                      state_size=n_states,
                      entropy_coef=args.entropy,
                      gamma=args.gamma,