def main():
    # setup parameters
    args = SimpleNamespace(
        env_module="environments",
        env_name="TargetEnv-v0",
        device="cuda:0" if torch.cuda.is_available() else "cpu",
        num_parallel=100,
        vae_path="models/",
        frame_skip=1,
        seed=16,
        load_saved_model=False,
    )

    args.num_parallel *= args.frame_skip
    env = make_gym_environment(args)

    # env parameters
    args.action_size = env.action_space.shape[0]
    args.observation_size = env.observation_space.shape[0]

    # other configs
    args.save_path = os.path.join(current_dir, "con_" + args.env_name + ".pt")

    # sampling parameters
    args.num_frames = 10e7
    args.num_steps_per_rollout = env.unwrapped.max_timestep
    args.num_updates = int(args.num_frames / args.num_parallel /
                           args.num_steps_per_rollout)

    # learning parameters
    args.lr = 3e-5
    args.final_lr = 1e-5
    args.eps = 1e-5
    args.lr_decay_type = "exponential"
    args.mini_batch_size = 1000
    args.num_mini_batch = (args.num_parallel * args.num_steps_per_rollout //
                           args.mini_batch_size)

    # ppo parameters
    use_gae = True
    entropy_coef = 0.0
    value_loss_coef = 1.0
    ppo_epoch = 10
    gamma = 0.99
    gae_lambda = 0.95
    clip_param = 0.2
    max_grad_norm = 1.0

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    obs_shape = env.observation_space.shape
    obs_shape = (obs_shape[0], *obs_shape[1:])

    if args.load_saved_model:
        actor_critic = torch.load(args.save_path, map_location=args.device)
        print("Loading model:", args.save_path)
    else:
        controller = PoseVAEController(env)
        actor_critic = PoseVAEPolicy(controller)

    actor_critic = actor_critic.to(args.device)
    actor_critic.env_info = {"frame_skip": args.frame_skip}

    agent = PPO(
        actor_critic,
        clip_param,
        ppo_epoch,
        args.num_mini_batch,
        value_loss_coef,
        entropy_coef,
        lr=args.lr,
        eps=args.eps,
        max_grad_norm=max_grad_norm,
    )

    rollouts = RolloutStorage(
        args.num_steps_per_rollout,
        args.num_parallel,
        obs_shape,
        args.action_size,
        actor_critic.state_size,
    )
    obs = env.reset()
    rollouts.observations[0].copy_(obs)
    rollouts.to(args.device)

    log_path = os.path.join(current_dir,
                            "log_ppo_progress-{}".format(args.env_name))
    logger = StatsLogger(csv_path=log_path)

    for update in range(args.num_updates):

        ep_info = {"reward": []}
        ep_reward = 0

        if args.lr_decay_type == "linear":
            update_linear_schedule(agent.optimizer, update, args.num_updates,
                                   args.lr, args.final_lr)
        elif args.lr_decay_type == "exponential":
            update_exponential_schedule(agent.optimizer, update, 0.99, args.lr,
                                        args.final_lr)

        for step in range(args.num_steps_per_rollout):
            # Sample actions
            with torch.no_grad():
                value, action, action_log_prob = actor_critic.act(
                    rollouts.observations[step])

            obs, reward, done, info = env.step(action)
            ep_reward += reward

            end_of_rollout = info.get("reset")
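            # masks: 1.0 where the episode continues, 0.0 where it just ended;
            # bad_masks additionally zeroes steps that ended only because the
            # rollout hit its time limit, so the return computation can
            # presumably treat truncation differently from true termination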
            masks = (~done).float()
            bad_masks = (~(done * end_of_rollout)).float()

            if done.any():
                ep_info["reward"].append(ep_reward[done].clone())
                ep_reward *= (~done).float()  # zero out the dones
                reset_indices = env.parallel_ind_buf.masked_select(
                    done.squeeze())
                obs = env.reset(reset_indices)

            if end_of_rollout:
                obs = env.reset()

            rollouts.insert(obs, action, action_log_prob, value, reward, masks,
                            bad_masks)

        with torch.no_grad():
            next_value = actor_critic.get_value(
                rollouts.observations[-1]).detach()

        rollouts.compute_returns(next_value, use_gae, gamma, gae_lambda)

        value_loss, action_loss, dist_entropy = agent.update(rollouts)

        rollouts.after_update()

        torch.save(copy.deepcopy(actor_critic).cpu(), args.save_path)

        ep_info["reward"] = torch.cat(ep_info["reward"])
        logger.log_stats(
            args,
            {
                "update": update,
                "ep_info": ep_info,
                "dist_entropy": dist_entropy,
                "value_loss": value_loss,
                "action_loss": action_loss,
            },
        )
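
A minimal sketch of what the two learning-rate schedulers called in the loop above (update_linear_schedule and update_exponential_schedule) are assumed to do. The signatures mirror the calls in the example, but the bodies are illustrative only; the project's actual helpers may differ.

def update_linear_schedule(optimizer, update, total_updates, initial_lr, final_lr):
    # anneal the learning rate linearly from initial_lr to final_lr
    lr = initial_lr - (initial_lr - final_lr) * (update / float(total_updates))
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr


def update_exponential_schedule(optimizer, update, decay, initial_lr, final_lr):
    # decay the learning rate geometrically by `decay` per update, floored at final_lr
    lr = max(initial_lr * decay ** update, final_lr)
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr
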
Example #2
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--env_name',
                        type=str,
                        default='coinrun',
                        help='name of the environment to train on.')
    parser.add_argument('--model',
                        type=str,
                        default='ppo',
                        help='the model to use for training. {ppo, ppo_aup}')
    args, rest_args = parser.parse_known_args()
    env_name = args.env_name
    model = args.model

    # --- ARGUMENTS ---

    if model == 'ppo':
        args = args_ppo.get_args(rest_args)
    elif model == 'ppo_aup':
        args = args_ppo_aup.get_args(rest_args)
    else:
        raise NotImplementedError

    # place other args back into argparse.Namespace
    args.env_name = env_name
    args.model = model

    # warnings
    if args.deterministic_execution:
        print('Invoking deterministic code execution.')
        if torch.backends.cudnn.enabled:
            warnings.warn('Running with deterministic CUDNN.')
        if args.num_processes > 1:
            raise RuntimeError(
                'If you want fully deterministic code, run it with num_processes=1. '
                'Warning: This will slow things down and might break A2C if '
                'policy_num_steps < env._max_episode_steps.')

    # --- TRAINING ---
    print("Setting up wandb logging.")

    # Weights & Biases logger
    if args.run_name is None:
        # make run name as {env_name}_{TIME}
        now = datetime.datetime.now().strftime('_%d-%m_%H:%M:%S')
        args.run_name = args.env_name + '_' + args.algo + now
    # initialise wandb
    wandb.init(project=args.proj_name,
               name=args.run_name,
               group=args.group_name,
               config=args,
               monitor_gym=False)
    # save wandb dir path
    args.run_dir = wandb.run.dir
    # make directory for saving models
    save_dir = os.path.join(wandb.run.dir, 'models')
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    # set random seed of random, torch and numpy
    utl.set_global_seed(args.seed, args.deterministic_execution)

    print("Setting up Environments.")
    # initialise environments for training
    train_envs = make_vec_envs(env_name=args.env_name,
                               start_level=args.train_start_level,
                               num_levels=args.train_num_levels,
                               distribution_mode=args.distribution_mode,
                               paint_vel_info=args.paint_vel_info,
                               num_processes=args.num_processes,
                               num_frame_stack=args.num_frame_stack,
                               device=device)
    # initialise environments for evaluation
    eval_envs = make_vec_envs(env_name=args.env_name,
                              start_level=0,
                              num_levels=0,
                              distribution_mode=args.distribution_mode,
                              paint_vel_info=args.paint_vel_info,
                              num_processes=args.num_processes,
                              num_frame_stack=args.num_frame_stack,
                              device=device)
    _ = eval_envs.reset()

    print("Setting up Actor-Critic model and Training algorithm.")
    # initialise policy network
    actor_critic = ACModel(obs_shape=train_envs.observation_space.shape,
                           action_space=train_envs.action_space,
                           hidden_size=args.hidden_size).to(device)

    # initialise policy training algorithm
    if args.algo == 'ppo':
        policy = PPO(actor_critic=actor_critic,
                     ppo_epoch=args.policy_ppo_epoch,
                     num_mini_batch=args.policy_num_mini_batch,
                     clip_param=args.policy_clip_param,
                     value_loss_coef=args.policy_value_loss_coef,
                     entropy_coef=args.policy_entropy_coef,
                     max_grad_norm=args.policy_max_grad_norm,
                     lr=args.policy_lr,
                     eps=args.policy_eps)
    else:
        raise NotImplementedError

    # initialise rollout storage for the policy training algorithm
    rollouts = RolloutStorage(num_steps=args.policy_num_steps,
                              num_processes=args.num_processes,
                              obs_shape=train_envs.observation_space.shape,
                              action_space=train_envs.action_space)

    # initialise Q_aux function(s) for AUP
    if args.use_aup:
        print("Initialising Q_aux models.")
        q_aux = [
            QModel(obs_shape=train_envs.observation_space.shape,
                   action_space=train_envs.action_space,
                   hidden_size=args.hidden_size).to(device)
            for _ in range(args.num_q_aux)
        ]
        if args.num_q_aux == 1:
            # load weights to model
            path = args.q_aux_dir + "0.pt"
            q_aux[0].load_state_dict(torch.load(path))
            q_aux[0].eval()
        else:
            # get max number of q_aux functions to choose from
            args.max_num_q_aux = len(os.listdir(args.q_aux_dir))
            q_aux_models = random.sample(list(range(0, args.max_num_q_aux)),
                                         args.num_q_aux)
            # load weights to models
            for i, model in enumerate(q_aux):
                path = args.q_aux_dir + str(q_aux_models[i]) + ".pt"
                model.load_state_dict(torch.load(path))
                model.eval()

    # count number of frames and updates
    frames = 0
    iter_idx = 0

    # update wandb args
    wandb.config.update(args)

    update_start_time = time.time()
    # reset environments
    obs = train_envs.reset()  # obs.shape = (num_processes,C,H,W)
    # insert initial observation to rollout storage
    rollouts.obs[0].copy_(obs)
    rollouts.to(device)

    # initialise buffer for calculating mean episodic returns
    episode_info_buf = deque(maxlen=10)

    # calculate number of updates
    # number of frames ÷ number of policy steps before update ÷ number of processes
    args.num_batch = args.num_processes * args.policy_num_steps
    args.num_updates = int(args.num_frames) // args.num_batch

    # define the AUP coefficient and its per-update growth factor
    # (a geometric schedule from aup_coef_start to aup_coef_end over num_updates)
    if args.use_aup:
        aup_coef = args.aup_coef_start
        aup_linear_increase_val = math.exp(
            math.log(args.aup_coef_end / args.aup_coef_start) /
            args.num_updates)

    print("Training beginning.")
    print("Number of updates: ", args.num_updates)
    for iter_idx in range(args.num_updates):
        print("Iter: ", iter_idx)

        # put actor-critic into train mode
        actor_critic.train()

        if args.use_aup:
            aup_measures = defaultdict(list)

        # rollout policy to collect num_batch of experience and place in storage
        for step in range(args.policy_num_steps):

            # sample actions from policy
            with torch.no_grad():
                value, action, action_log_prob = actor_critic.act(
                    rollouts.obs[step])

            # observe rewards and next obs
            obs, reward, done, infos = train_envs.step(action)

            # calculate AUP reward
            if args.use_aup:
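                # AUP penalty: for each auxiliary Q-function, compute
                # |Q_aux(s, a) - Q_aux(s, no-op)| (action index 4 is assumed to
                # be the Procgen no-op) and average over the args.num_q_aux
                # models; a standalone sketch of this follows the example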
                intrinsic_reward = torch.zeros_like(reward)
                with torch.no_grad():
                    for model in q_aux:
                        # get action-values
                        action_values = model.get_action_value(
                            rollouts.obs[step])
                        # get action-value for action taken
                        action_value = torch.sum(
                            action_values * torch.nn.functional.one_hot(
                                action,
                                num_classes=train_envs.action_space.n).squeeze(
                                    dim=1),
                            dim=1)
                        # calculate the penalty
                        intrinsic_reward += torch.abs(
                            action_value.unsqueeze(dim=1) -
                            action_values[:, 4].unsqueeze(dim=1))
                intrinsic_reward /= args.num_q_aux
                # add intrinsic reward to the extrinsic reward
                reward -= aup_coef * intrinsic_reward
                # log the intrinsic reward from the first env.
                aup_measures['intrinsic_reward'].append(aup_coef *
                                                        intrinsic_reward[0, 0])
                if done[0] and infos[0]['prev_level_complete'] == 1:
                    aup_measures['episode_complete'].append(2)
                elif done[0] and infos[0]['prev_level_complete'] == 0:
                    aup_measures['episode_complete'].append(1)
                else:
                    aup_measures['episode_complete'].append(0)

            # log episode info if episode finished
            for info in infos:
                if 'episode' in info.keys():
                    episode_info_buf.append(info['episode'])

            # create mask for episode ends
            masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                       for done_ in done]).to(device)

            # add experience to storage
            rollouts.insert(obs, reward, action, value, action_log_prob, masks)

            frames += args.num_processes

        # increase the aup coefficient geometrically after every update
        # (despite the "linear" variable name, this is an exponential schedule)
        if args.use_aup:
            aup_coef *= aup_linear_increase_val

        # --- UPDATE ---

        # bootstrap next value prediction
        with torch.no_grad():
            next_value = actor_critic.get_value(rollouts.obs[-1]).detach()

        # compute returns for current rollouts
        rollouts.compute_returns(next_value, args.policy_gamma,
                                 args.policy_gae_lambda)

        # update actor-critic using policy training algorithm
        total_loss, value_loss, action_loss, dist_entropy = policy.update(
            rollouts)

        # clean up storage after update
        rollouts.after_update()

        # --- LOGGING ---

        if iter_idx % args.log_interval == 0 or iter_idx == args.num_updates - 1:

            # --- EVALUATION ---
            eval_episode_info_buf = utl_eval.evaluate(
                eval_envs=eval_envs, actor_critic=actor_critic, device=device)

            # get stats for run
            update_end_time = time.time()
            num_interval_updates = 1 if iter_idx == 0 else args.log_interval
            fps = num_interval_updates * (
                args.num_processes *
                args.policy_num_steps) / (update_end_time - update_start_time)
            update_start_time = update_end_time
            # calculates whether the value function is a good predictor of the returns (ev close to 1)
            # or if it's just worse than predicting nothing (ev <= 0)
            ev = utl_math.explained_variance(utl.sf01(rollouts.value_preds),
                                             utl.sf01(rollouts.returns))

            if args.use_aup:
                step = frames - args.num_processes * args.policy_num_steps
                for i in range(args.policy_num_steps):
                    wandb.log(
                        {
                            'aup/intrinsic_reward':
                            aup_measures['intrinsic_reward'][i],
                            'aup/episode_complete':
                            aup_measures['episode_complete'][i]
                        },
                        step=step)
                    step += args.num_processes

            wandb.log(
                {
                    'misc/timesteps':
                    frames,
                    'misc/fps':
                    fps,
                    'misc/explained_variance':
                    float(ev),
                    'losses/total_loss':
                    total_loss,
                    'losses/value_loss':
                    value_loss,
                    'losses/action_loss':
                    action_loss,
                    'losses/dist_entropy':
                    dist_entropy,
                    'train/mean_episodic_return':
                    utl_math.safe_mean([
                        episode_info['r'] for episode_info in episode_info_buf
                    ]),
                    'train/mean_episodic_length':
                    utl_math.safe_mean([
                        episode_info['l'] for episode_info in episode_info_buf
                    ]),
                    'eval/mean_episodic_return':
                    utl_math.safe_mean([
                        episode_info['r']
                        for episode_info in eval_episode_info_buf
                    ]),
                    'eval/mean_episodic_length':
                    utl_math.safe_mean([
                        episode_info['l']
                        for episode_info in eval_episode_info_buf
                    ])
                },
                step=frames)

        # --- SAVE MODEL ---

        # save for every interval-th episode or for the last epoch
        if iter_idx != 0 and (iter_idx % args.save_interval == 0
                              or iter_idx == args.num_updates - 1):
            print("Saving Actor-Critic Model.")
            torch.save(actor_critic.state_dict(),
                       os.path.join(save_dir, "policy{0}.pt".format(iter_idx)))

    # close envs
    train_envs.close()
    eval_envs.close()

    # --- TEST ---

    if args.test:
        print("Testing beginning.")
        episodic_return = utl_test.test(args=args,
                                        actor_critic=actor_critic,
                                        device=device)

        # save returns from train and test levels to analyse using interactive mode
        train_levels = torch.arange(
            args.train_start_level,
            args.train_start_level + args.train_num_levels)
        for i, level in enumerate(train_levels):
            wandb.log({
                'test/train_levels': level,
                'test/train_returns': episodic_return[0][i]
            })
        test_levels = torch.arange(
            args.test_start_level,
            args.test_start_level + args.test_num_levels)
        for i, level in enumerate(test_levels):
            wandb.log({
                'test/test_levels': level,
                'test/test_returns': episodic_return[1][i]
            })
        # log returns from test envs
        wandb.run.summary["train_mean_episodic_return"] = utl_math.safe_mean(
            episodic_return[0])
        wandb.run.summary["test_mean_episodic_return"] = utl_math.safe_mean(
            episodic_return[1])
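
The AUP intrinsic penalty accumulated in the inner rollout loop above can be summarised as a small helper. This is an illustrative sketch only: the no-op index 4 and the averaging over auxiliary Q-functions follow the example, while the helper name and the use of gather are assumptions.

def aup_penalty(q_values_per_model, actions, noop_index=4):
    # q_values_per_model: list of (num_envs, num_actions) tensors, one per
    #                     auxiliary Q-function (e.g. model.get_action_value(obs))
    # actions:            (num_envs, 1) long tensor of actions taken
    # returns:            (num_envs, 1) tensor, the mean over auxiliary models
    #                     of |Q_aux(s, a) - Q_aux(s, no-op)|
    penalty = None
    for q_values in q_values_per_model:
        q_taken = q_values.gather(1, actions)            # Q_aux(s, a)
        q_noop = q_values[:, noop_index:noop_index + 1]  # Q_aux(s, no-op)
        diff = (q_taken - q_noop).abs()
        penalty = diff if penalty is None else penalty + diff
    return penalty / len(q_values_per_model)

The shaped reward in the example would then be reward - aup_coef * aup_penalty([m.get_action_value(obs) for m in q_aux], action).
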
Example #3
File: pretrain.py  Project: udeepam/aup
def main():

    # --- ARGUMENTS ---
    parser = argparse.ArgumentParser()
    parser.add_argument('--env_name',
                        type=str,
                        default='coinrun',
                        help='name of the environment to train on.')
    parser.add_argument('--model',
                        type=str,
                        default='ppo',
                        help='the model to use for training.')
    args, rest_args = parser.parse_known_args()
    env_name = args.env_name
    model = args.model

    # get arguments
    args = args_pretrain_aup.get_args(rest_args)
    # place other args back into argparse.Namespace
    args.env_name = env_name
    args.model = model

    # Weights & Biases logger
    if args.run_name is None:
        # make run name as {env_name}_{TIME}
        now = datetime.datetime.now().strftime('_%d-%m_%H:%M:%S')
        args.run_name = args.env_name + '_' + args.algo + now
    # initialise wandb
    wandb.init(name=args.run_name,
               project=args.proj_name,
               group=args.group_name,
               config=args,
               monitor_gym=False)
    # save wandb dir path
    args.run_dir = wandb.run.dir
    wandb.config.update(args)

    # set random seed of random, torch and numpy
    utl.set_global_seed(args.seed, args.deterministic_execution)

    # --- OBTAIN DATA FOR TRAINING R_aux ---
    print("Gathering data for R_aux Model.")

    # gather observations for pretraining the auxiliary reward function (CB-VAE)
    envs = make_vec_envs(env_name=args.env_name,
                         start_level=0,
                         num_levels=0,
                         distribution_mode=args.distribution_mode,
                         paint_vel_info=args.paint_vel_info,
                         num_processes=args.num_processes,
                         num_frame_stack=args.num_frame_stack,
                         device=device)

    # number of frames ÷ number of policy steps before update ÷ number of cpu processes
    num_batch = args.num_processes * args.policy_num_steps
    num_updates = int(args.num_frames_r_aux) // num_batch

    # create list to store env observations
    obs_data = torch.zeros(num_updates * args.policy_num_steps + 1,
                           args.num_processes, *envs.observation_space.shape)
    # reset environments
    obs = envs.reset()  # obs.shape = (n_env,C,H,W)
    obs_data[0].copy_(obs)
    obs = obs.to(device)

    for iter_idx in range(num_updates):
        # rollout policy to collect num_batch of experience and store in storage
        for step in range(args.policy_num_steps):
            # sample actions from random agent
            action = torch.randint(0, envs.action_space.n,
                                   (args.num_processes, 1))
            # observe rewards and next obs
            obs, reward, done, infos = envs.step(action)
            # store obs
            obs_data[1 + iter_idx * args.policy_num_steps + step].copy_(obs)
    # close envs
    envs.close()

    # --- TRAIN R_aux (CB-VAE) ---
    # define CB-VAE where the encoder will be used as the auxiliary reward function R_aux
    print("Training R_aux Model.")

    # create dataloader for observations gathered
    obs_data = obs_data.reshape(-1, *envs.observation_space.shape)
    sampler = BatchSampler(SubsetRandomSampler(range(obs_data.size(0))),
                           args.cb_vae_batch_size,
                           drop_last=False)

    # initialise CB-VAE
    cb_vae = CBVAE(obs_shape=envs.observation_space.shape,
                   latent_dim=args.cb_vae_latent_dim).to(device)
    # optimiser
    optimiser = torch.optim.Adam(cb_vae.parameters(),
                                 lr=args.cb_vae_learning_rate)
    # put CB-VAE into train mode
    cb_vae.train()

    measures = defaultdict(list)
    for epoch in range(args.cb_vae_epochs):
        print("Epoch: ", epoch)
        start_time = time.time()
        batch_loss = 0
        for indices in sampler:
            obs = obs_data[indices].to(device)
            # zero accumulated gradients
            cb_vae.zero_grad()
            # forward pass through CB-VAE
            recon_batch, mu, log_var = cb_vae(obs)
            # calculate loss
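            # cb_vae_loss is presumably the usual ELBO: a continuous-Bernoulli
            # reconstruction likelihood plus the Gaussian KL term
            # -0.5 * sum(1 + log_var - mu^2 - exp(log_var))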
            loss = cb_vae_loss(recon_batch, obs, mu, log_var)
            # backpropagation: calculating gradients
            loss.backward()
            # update parameters of the CB-VAE
            optimiser.step()

            # save loss per mini-batch
            batch_loss += loss.item() * obs.size(0)
        # log losses per epoch
        wandb.log({
            'cb_vae/loss': batch_loss / obs_data.size(0),
            'cb_vae/time_taken': time.time() - start_time,
            'cb_vae/epoch': epoch
        })
    indices = np.random.randint(0, obs.size(0), args.cb_vae_num_samples**2)
    measures['true_images'].append(obs[indices].detach().cpu().numpy())
    measures['recon_images'].append(
        recon_batch[indices].detach().cpu().numpy())

    # plot ground truth images
    plt.rcParams.update({'font.size': 10})
    fig, axs = plt.subplots(args.cb_vae_num_samples,
                            args.cb_vae_num_samples,
                            figsize=(20, 20))
    for i, img in enumerate(measures['true_images'][0]):
        axs[i // args.cb_vae_num_samples][i % args.cb_vae_num_samples].imshow(
            img.transpose(1, 2, 0))
        axs[i // args.cb_vae_num_samples][i %
                                          args.cb_vae_num_samples].axis('off')
    wandb.log({"Ground Truth Images": wandb.Image(plt)})
    # plot reconstructed images
    fig, axs = plt.subplots(args.cb_vae_num_samples,
                            args.cb_vae_num_samples,
                            figsize=(20, 20))
    for i, img in enumerate(measures['recon_images'][0]):
        axs[i // args.cb_vae_num_samples][i % args.cb_vae_num_samples].imshow(
            img.transpose(1, 2, 0))
        axs[i // args.cb_vae_num_samples][i %
                                          args.cb_vae_num_samples].axis('off')
    wandb.log({"Reconstructed Images": wandb.Image(plt)})

    # --- TRAIN Q_aux --
    # train PPO agent with value head replaced with action-value head and training on R_aux instead of the environment R
    print("Training Q_aux Model.")

    # initialise environments for training Q_aux
    envs = make_vec_envs(env_name=args.env_name,
                         start_level=0,
                         num_levels=0,
                         distribution_mode=args.distribution_mode,
                         paint_vel_info=args.paint_vel_info,
                         num_processes=args.num_processes,
                         num_frame_stack=args.num_frame_stack,
                         device=device)

    # initialise policy network
    actor_critic = QModel(obs_shape=envs.observation_space.shape,
                          action_space=envs.action_space,
                          hidden_size=args.hidden_size).to(device)

    # initialise policy trainer
    if args.algo == 'ppo':
        policy = PPO(actor_critic=actor_critic,
                     ppo_epoch=args.policy_ppo_epoch,
                     num_mini_batch=args.policy_num_mini_batch,
                     clip_param=args.policy_clip_param,
                     value_loss_coef=args.policy_value_loss_coef,
                     entropy_coef=args.policy_entropy_coef,
                     max_grad_norm=args.policy_max_grad_norm,
                     lr=args.policy_lr,
                     eps=args.policy_eps)
    else:
        raise NotImplementedError

    # initialise rollout storage for the policy
    rollouts = RolloutStorage(num_steps=args.policy_num_steps,
                              num_processes=args.num_processes,
                              obs_shape=envs.observation_space.shape,
                              action_space=envs.action_space)

    # count number of frames and updates
    frames = 0
    iter_idx = 0

    update_start_time = time.time()
    # reset environments
    obs = envs.reset()  # obs.shape = (n_envs,C,H,W)
    # insert initial observation to rollout storage
    rollouts.obs[0].copy_(obs)
    rollouts.to(device)

    # initialise buffer for calculating mean episodic returns
    episode_info_buf = deque(maxlen=10)

    # calculate number of updates
    # number of frames ÷ number of policy steps before update ÷ number of cpu processes
    args.num_batch = args.num_processes * args.policy_num_steps
    args.num_updates = int(args.num_frames_q_aux) // args.num_batch
    print("Number of updates: ", args.num_updates)
    for iter_idx in range(args.num_updates):
        print("Iter: ", iter_idx)

        # put actor-critic into train mode
        actor_critic.train()

        # rollout policy to collect num_batch of experience and store in storage
        for step in range(args.policy_num_steps):

            with torch.no_grad():
                # sample actions from policy
                value, action, action_log_prob = actor_critic.act(
                    rollouts.obs[step])
                # obtain reward R_aux from encoder of CB-VAE
                r_aux, _, _ = cb_vae.encode(rollouts.obs[step])

            # observe rewards and next obs
            obs, _, done, infos = envs.step(action)

            # log episode info if episode finished
            for i, info in enumerate(infos):
                if 'episode' in info.keys():
                    episode_info_buf.append(info['episode'])
            # create mask for episode ends
            masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                       for done_ in done]).to(device)

            # add experience to policy buffer
            rollouts.insert(obs, r_aux, action, value, action_log_prob, masks)

            frames += args.num_processes

        # --- UPDATE ---

        # bootstrap next value prediction
        with torch.no_grad():
            next_value = actor_critic.get_value(rollouts.obs[-1]).detach()

        # compute returns for current rollouts
        rollouts.compute_returns(next_value, args.policy_gamma,
                                 args.policy_gae_lambda)

        # update actor-critic using policy gradient algo
        total_loss, value_loss, action_loss, dist_entropy = policy.update(
            rollouts)

        # clean up after update
        rollouts.after_update()

        # --- LOGGING ---

        if iter_idx % args.log_interval == 0 or iter_idx == args.num_updates - 1:
            # get stats for run
            update_end_time = time.time()
            num_interval_updates = 1 if iter_idx == 0 else args.log_interval
            fps = num_interval_updates * (
                args.num_processes *
                args.policy_num_steps) / (update_end_time - update_start_time)
            update_start_time = update_end_time
            # Calculates whether the value function is a good predictor of the returns (ev close to 1)
            # or if it's just worse than predicting nothing (ev <= 0)
            ev = utl_math.explained_variance(utl.sf01(rollouts.value_preds),
                                             utl.sf01(rollouts.returns))

            wandb.log({
                'q_aux_misc/timesteps':
                frames,
                'q_aux_misc/fps':
                fps,
                'q_aux_misc/explained_variance':
                float(ev),
                'q_aux_losses/total_loss':
                total_loss,
                'q_aux_losses/value_loss':
                value_loss,
                'q_aux_losses/action_loss':
                action_loss,
                'q_aux_losses/dist_entropy':
                dist_entropy,
                'q_aux_train/mean_episodic_return':
                utl_math.safe_mean(
                    [episode_info['r'] for episode_info in episode_info_buf]),
                'q_aux_train/mean_episodic_length':
                utl_math.safe_mean(
                    [episode_info['l'] for episode_info in episode_info_buf])
            })

    # close envs
    envs.close()

    # --- SAVE MODEL ---
    print("Saving Q_aux Model.")
    torch.save(actor_critic.state_dict(), args.q_aux_path)
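
Every example above calls a compute_returns method on its rollout storage after collecting a rollout. Below is a minimal sketch of the GAE(lambda) return computation it presumably performs, assuming masks[t] is 0.0 where the episode ended at step t and 1.0 otherwise; the function name and tensor layout are assumptions, not the project's actual code.

import torch


def compute_gae_returns(rewards, values, masks, next_value, gamma, gae_lambda):
    # rewards, values, masks: (num_steps, num_envs, 1) tensors
    # next_value:             (num_envs, 1) bootstrap value for the post-rollout state
    values = torch.cat([values, next_value.unsqueeze(0)], dim=0)
    returns = torch.zeros_like(rewards)
    gae = torch.zeros_like(next_value)
    for t in reversed(range(rewards.size(0))):
        # TD error for step t, cut off at episode boundaries by masks[t]
        delta = rewards[t] + gamma * values[t + 1] * masks[t] - values[t]
        # exponentially weighted sum of TD errors (the GAE advantage)
        gae = delta + gamma * gae_lambda * masks[t] * gae
        # return target used by PPO: advantage plus the value baseline
        returns[t] = gae + values[t]
    return returns
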
Example #4
File: main.py  Project: udeepam/aldm
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--env_name',
                        type=str,
                        default='coinrun',
                        help='name of the environment to train on.')
    parser.add_argument(
        '--model',
        type=str,
        default='ppo',
        help='the model to use for training. {ppo, ibac, ibac_sni, dist_match}'
    )
    args, rest_args = parser.parse_known_args()
    env_name = args.env_name
    model = args.model

    # --- ARGUMENTS ---

    if model == 'ppo':
        args = args_ppo.get_args(rest_args)
    elif model == 'ibac':
        args = args_ibac.get_args(rest_args)
    elif model == 'ibac_sni':
        args = args_ibac_sni.get_args(rest_args)
    elif model == 'dist_match':
        args = args_dist_match.get_args(rest_args)
    else:
        raise NotImplementedError

    # place other args back into argparse.Namespace
    args.env_name = env_name
    args.model = model
    args.num_train_envs = args.num_processes - args.num_val_envs if args.num_val_envs > 0 else args.num_processes

    # warnings
    if args.deterministic_execution:
        print('Invoking deterministic code execution.')
        if torch.backends.cudnn.enabled:
            warnings.warn('Running with deterministic CUDNN.')
        if args.num_processes > 1:
            raise RuntimeError(
                'If you want fully deterministic code, run it with num_processes=1. '
                'Warning: This will slow things down and might break A2C if '
                'policy_num_steps < env._max_episode_steps.')

    elif args.num_val_envs > 0 and (args.num_val_envs >= args.num_processes
                                    or not args.percentage_levels_train < 1.0):
        raise ValueError(
            'If --num_val_envs>0 then you must also have '
            '--num_val_envs < --num_processes and 0 < --percentage_levels_train < 1.'
        )

    elif args.num_val_envs > 0 and not args.use_dist_matching and args.dist_matching_coef != 0:
        raise ValueError(
            'If --num_val_envs>0 and --use_dist_matching=False then you must also have '
            '--dist_matching_coef=0.')

    elif args.use_dist_matching and not args.num_val_envs > 0:
        raise ValueError(
            'If --use_dist_matching=True then you must also have '
            '0 < --num_val_envs < --num_processes and 0 < --percentage_levels_train < 1.'
        )

    elif args.analyse_rep and not args.use_bottleneck:
        raise ValueError('If --analyse_rep=True then you must also have '
                         '--use_bottleneck=True.')

    # --- TRAINING ---
    print("Setting up wandb logging.")

    # Weights & Biases logger
    if args.run_name is None:
        # make run name as {env_name}_{TIME}
        now = datetime.datetime.now().strftime('_%d-%m_%H:%M:%S')
        args.run_name = args.env_name + '_' + args.algo + now
    # initialise wandb
    wandb.init(project=args.proj_name,
               name=args.run_name,
               group=args.group_name,
               config=args,
               monitor_gym=False)
    # save wandb dir path
    args.run_dir = wandb.run.dir
    # make directory for saving models
    save_dir = os.path.join(wandb.run.dir, 'models')
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    # set random seed of random, torch and numpy
    utl.set_global_seed(args.seed, args.deterministic_execution)

    # initialise environments for training
    print("Setting up Environments.")
    if args.num_val_envs > 0:
        train_num_levels = int(args.train_num_levels *
                               args.percentage_levels_train)
        val_start_level = args.train_start_level + train_num_levels
        val_num_levels = args.train_num_levels - train_num_levels
        train_envs = make_vec_envs(env_name=args.env_name,
                                   start_level=args.train_start_level,
                                   num_levels=train_num_levels,
                                   distribution_mode=args.distribution_mode,
                                   paint_vel_info=args.paint_vel_info,
                                   num_processes=args.num_train_envs,
                                   num_frame_stack=args.num_frame_stack,
                                   device=device)
        val_envs = make_vec_envs(env_name=args.env_name,
                                 start_level=val_start_level,
                                 num_levels=val_num_levels,
                                 distribution_mode=args.distribution_mode,
                                 paint_vel_info=args.paint_vel_info,
                                 num_processes=args.num_val_envs,
                                 num_frame_stack=args.num_frame_stack,
                                 device=device)
    else:
        train_envs = make_vec_envs(env_name=args.env_name,
                                   start_level=args.train_start_level,
                                   num_levels=args.train_num_levels,
                                   distribution_mode=args.distribution_mode,
                                   paint_vel_info=args.paint_vel_info,
                                   num_processes=args.num_processes,
                                   num_frame_stack=args.num_frame_stack,
                                   device=device)
    # initialise environments for evaluation
    eval_envs = make_vec_envs(env_name=args.env_name,
                              start_level=0,
                              num_levels=0,
                              distribution_mode=args.distribution_mode,
                              paint_vel_info=args.paint_vel_info,
                              num_processes=args.num_processes,
                              num_frame_stack=args.num_frame_stack,
                              device=device)
    _ = eval_envs.reset()
    # initialise environments for analysing the representation
    if args.analyse_rep:
        analyse_rep_train1_envs, analyse_rep_train2_envs, analyse_rep_val_envs, analyse_rep_test_envs = make_rep_analysis_envs(
            args, device)

    print("Setting up Actor-Critic model and Training algorithm.")
    # initialise policy network
    actor_critic = ACModel(obs_shape=train_envs.observation_space.shape,
                           action_space=train_envs.action_space,
                           hidden_size=args.hidden_size,
                           use_bottleneck=args.use_bottleneck,
                           sni_type=args.sni_type).to(device)

    # initialise policy training algorithm
    if args.algo == 'ppo':
        policy = PPO(actor_critic=actor_critic,
                     ppo_epoch=args.policy_ppo_epoch,
                     num_mini_batch=args.policy_num_mini_batch,
                     clip_param=args.policy_clip_param,
                     value_loss_coef=args.policy_value_loss_coef,
                     entropy_coef=args.policy_entropy_coef,
                     max_grad_norm=args.policy_max_grad_norm,
                     lr=args.policy_lr,
                     eps=args.policy_eps,
                     vib_coef=args.vib_coef,
                     sni_coef=args.sni_coef,
                     use_dist_matching=args.use_dist_matching,
                     dist_matching_loss=args.dist_matching_loss,
                     dist_matching_coef=args.dist_matching_coef,
                     num_train_envs=args.num_train_envs,
                     num_val_envs=args.num_val_envs)
    else:
        raise NotImplementedError

    # initialise rollout storage for the policy training algorithm
    rollouts = RolloutStorage(num_steps=args.policy_num_steps,
                              num_processes=args.num_processes,
                              obs_shape=train_envs.observation_space.shape,
                              action_space=train_envs.action_space)

    # count number of frames and updates
    frames = 0
    iter_idx = 0

    # update wandb args
    wandb.config.update(args)
    # wandb.watch(actor_critic, log="all") # to log gradients of actor-critic network

    update_start_time = time.time()

    # reset environments
    if args.num_val_envs > 0:
        obs = torch.cat([train_envs.reset(),
                         val_envs.reset()])  # obs.shape = (n_envs,C,H,W)
    else:
        obs = train_envs.reset()  # obs.shape = (n_envs,C,H,W)

    # insert initial observation to rollout storage
    rollouts.obs[0].copy_(obs)
    rollouts.to(device)

    # initialise buffer for calculating mean episodic returns
    train_episode_info_buf = deque(maxlen=10)
    val_episode_info_buf = deque(maxlen=10)

    # calculate number of updates
    # number of frames ÷ number of policy steps before update ÷ number of processes
    args.num_batch = args.num_processes * args.policy_num_steps
    args.num_updates = int(args.num_frames) // args.num_batch
    print("Training beginning.")
    print("Number of updates: ", args.num_updates)
    for iter_idx in range(args.num_updates):
        print("Iter: ", iter_idx)

        # put actor-critic into train mode
        actor_critic.train()

        # rollout policy to collect num_batch of experience and place in storage
        for step in range(args.policy_num_steps):

            # sample actions from policy
            with torch.no_grad():
                value, action, action_log_prob, _ = actor_critic.act(
                    rollouts.obs[step])

            # observe rewards and next obs
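            # the first num_train_envs rows of the action batch belong to the
            # training envs and the remaining rows to the validation envs, so
            # the step results are split and re-concatenated in that order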
            if args.num_val_envs > 0:
                obs, reward, done, infos = train_envs.step(
                    action[:args.num_train_envs, :])
                val_obs, val_reward, val_done, val_infos = val_envs.step(
                    action[args.num_train_envs:, :])
                obs = torch.cat([obs, val_obs])
                reward = torch.cat([reward, val_reward])
                done, val_done = list(done), list(val_done)
                done.extend(val_done)
                infos.extend(val_infos)
            else:
                obs, reward, done, infos = train_envs.step(action)

            # log episode info if episode finished
            for i, info in enumerate(infos):
                if i < args.num_train_envs and 'episode' in info.keys():
                    train_episode_info_buf.append(info['episode'])
                elif i >= args.num_train_envs and 'episode' in info.keys():
                    val_episode_info_buf.append(info['episode'])

            # create mask for episode ends
            masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                       for done_ in done]).to(device)

            # add experience to storage
            rollouts.insert(obs, reward, action, value, action_log_prob, masks)

            frames += args.num_processes

        # --- UPDATE ---

        # bootstrap next value prediction
        with torch.no_grad():
            next_value = actor_critic.get_value(rollouts.obs[-1]).detach()

        # compute returns for current rollouts
        rollouts.compute_returns(next_value, args.policy_gamma,
                                 args.policy_gae_lambda)

        # update actor-critic using policy gradient algo
        total_loss, value_loss, action_loss, dist_entropy, vib_kl, dist_matching_loss = policy.update(
            rollouts)

        # clean up storage after update
        rollouts.after_update()

        # --- LOGGING ---

        if iter_idx % args.log_interval == 0 or iter_idx == args.num_updates - 1:

            # --- EVALUATION ---
            eval_episode_info_buf = utl_eval.evaluate(
                eval_envs=eval_envs, actor_critic=actor_critic, device=device)

            # --- ANALYSE REPRESENTATION ---
            if args.analyse_rep:
                rep_measures = utl_rep.analyse_rep(
                    args=args,
                    train1_envs=analyse_rep_train1_envs,
                    train2_envs=analyse_rep_train2_envs,
                    val_envs=analyse_rep_val_envs,
                    test_envs=analyse_rep_test_envs,
                    actor_critic=actor_critic,
                    device=device)

            # get stats for run
            update_end_time = time.time()
            num_interval_updates = 1 if iter_idx == 0 else args.log_interval
            fps = num_interval_updates * (
                args.num_processes *
                args.policy_num_steps) / (update_end_time - update_start_time)
            update_start_time = update_end_time
            # Calculates whether the value function is a good predictor of the returns (ev close to 1)
            # or if it's just worse than predicting nothing (ev <= 0)
            ev = utl_math.explained_variance(utl.sf01(rollouts.value_preds),
                                             utl.sf01(rollouts.returns))

            wandb.log(
                {
                    'misc/timesteps':
                    frames,
                    'misc/fps':
                    fps,
                    'misc/explained_variance':
                    float(ev),
                    'losses/total_loss':
                    total_loss,
                    'losses/value_loss':
                    value_loss,
                    'losses/action_loss':
                    action_loss,
                    'losses/dist_entropy':
                    dist_entropy,
                    'train/mean_episodic_return':
                    utl_math.safe_mean([
                        episode_info['r']
                        for episode_info in train_episode_info_buf
                    ]),
                    'train/mean_episodic_length':
                    utl_math.safe_mean([
                        episode_info['l']
                        for episode_info in train_episode_info_buf
                    ]),
                    'eval/mean_episodic_return':
                    utl_math.safe_mean([
                        episode_info['r']
                        for episode_info in eval_episode_info_buf
                    ]),
                    'eval/mean_episodic_length':
                    utl_math.safe_mean([
                        episode_info['l']
                        for episode_info in eval_episode_info_buf
                    ])
                },
                step=iter_idx)
            if args.use_bottleneck:
                wandb.log({'losses/vib_kl': vib_kl}, step=iter_idx)
            if args.num_val_envs > 0:
                wandb.log(
                    {
                        'losses/dist_matching_loss':
                        dist_matching_loss,
                        'val/mean_episodic_return':
                        utl_math.safe_mean([
                            episode_info['r']
                            for episode_info in val_episode_info_buf
                        ]),
                        'val/mean_episodic_length':
                        utl_math.safe_mean([
                            episode_info['l']
                            for episode_info in val_episode_info_buf
                        ])
                    },
                    step=iter_idx)
            if args.analyse_rep:
                wandb.log(
                    {
                        "analysis/" + key: val
                        for key, val in rep_measures.items()
                    },
                    step=iter_idx)

        # --- SAVE MODEL ---

        # save for every interval-th episode or for the last epoch
        if iter_idx != 0 and (iter_idx % args.save_interval == 0
                              or iter_idx == args.num_updates - 1):
            print("Saving Actor-Critic Model.")
            torch.save(actor_critic.state_dict(),
                       os.path.join(save_dir, "policy{0}.pt".format(iter_idx)))

    # close envs
    train_envs.close()
    eval_envs.close()

    # --- TEST ---

    if args.test:
        print("Testing beginning.")
        episodic_return, latents_z = utl_test.test(args=args,
                                                   actor_critic=actor_critic,
                                                   device=device)

        # save returns from train and test levels to analyse using interactive mode
        train_levels = torch.arange(
            args.train_start_level,
            args.train_start_level + args.train_num_levels)
        for i, level in enumerate(train_levels):
            wandb.log({
                'test/train_levels': level,
                'test/train_returns': episodic_return[0][i]
            })
        test_levels = torch.arange(
            args.test_start_level,
            args.test_start_level + args.test_num_levels)
        for i, level in enumerate(test_levels):
            wandb.log({
                'test/test_levels': level,
                'test/test_returns': episodic_return[1][i]
            })
        # log returns from test envs
        wandb.run.summary["train_mean_episodic_return"] = utl_math.safe_mean(
            episodic_return[0])
        wandb.run.summary["test_mean_episodic_return"] = utl_math.safe_mean(
            episodic_return[1])

        # plot latent representation
        if args.plot_pca:
            print("Plotting PCA of Latent Representation.")
            utl_rep.pca(args, latents_z)
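
The logging in these examples relies on utl.sf01 and utl_math.explained_variance. Hedged sketches of what they are assumed to compute, following the usual swap-and-flatten and explained-variance conventions; the project's actual implementations may differ.

import numpy as np


def sf01(tensor):
    # swap the (num_steps, num_processes) axes and flatten them into one batch axis
    arr = tensor.detach().cpu().numpy()
    return arr.swapaxes(0, 1).reshape(arr.shape[0] * arr.shape[1], *arr.shape[2:])


def explained_variance(y_pred, y_true):
    # 1 - Var[y_true - y_pred] / Var[y_true]:
    #   1  -> the value function predicts the returns perfectly
    #   0  -> no better than predicting a constant
    #  <0  -> worse than predicting a constant
    y_pred, y_true = np.ravel(y_pred), np.ravel(y_true)
    var_y = np.var(y_true)
    return np.nan if var_y == 0 else 1.0 - np.var(y_true - y_pred) / var_y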