def main():
    # setup parameters
    args = SimpleNamespace(
        env_module="environments",
        env_name="TargetEnv-v0",
        device="cuda:0" if torch.cuda.is_available() else "cpu",
        num_parallel=100,
        vae_path="models/",
        frame_skip=1,
        seed=16,
        load_saved_model=False,
    )

    args.num_parallel *= args.frame_skip
    env = make_gym_environment(args)

    # env parameters
    args.action_size = env.action_space.shape[0]
    args.observation_size = env.observation_space.shape[0]

    # other configs
    args.save_path = os.path.join(current_dir, "con_" + args.env_name + ".pt")

    # sampling parameters
    args.num_frames = 10e7
    args.num_steps_per_rollout = env.unwrapped.max_timestep
    args.num_updates = int(args.num_frames / args.num_parallel /
                           args.num_steps_per_rollout)

    # learning parameters
    args.lr = 3e-5
    args.final_lr = 1e-5
    args.eps = 1e-5
    args.lr_decay_type = "exponential"
    args.mini_batch_size = 1000
    args.num_mini_batch = (args.num_parallel * args.num_steps_per_rollout //
                           args.mini_batch_size)

    # ppo parameters
    use_gae = True
    entropy_coef = 0.0
    value_loss_coef = 1.0
    ppo_epoch = 10
    gamma = 0.99
    gae_lambda = 0.95
    clip_param = 0.2
    max_grad_norm = 1.0

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    obs_shape = env.observation_space.shape
    obs_shape = (obs_shape[0], *obs_shape[1:])

    if args.load_saved_model:
        actor_critic = torch.load(args.save_path, map_location=args.device)
        print("Loading model:", args.save_path)
    else:
        controller = PoseVAEController(env)
        actor_critic = PoseVAEPolicy(controller)

    actor_critic = actor_critic.to(args.device)
    actor_critic.env_info = {"frame_skip": args.frame_skip}

    agent = PPO(
        actor_critic,
        clip_param,
        ppo_epoch,
        args.num_mini_batch,
        value_loss_coef,
        entropy_coef,
        lr=args.lr,
        eps=args.eps,
        max_grad_norm=max_grad_norm,
    )

    rollouts = RolloutStorage(
        args.num_steps_per_rollout,
        args.num_parallel,
        obs_shape,
        args.action_size,
        actor_critic.state_size,
    )
    obs = env.reset()
    rollouts.observations[0].copy_(obs)
    rollouts.to(args.device)

    log_path = os.path.join(current_dir,
                            "log_ppo_progress-{}".format(args.env_name))
    logger = StatsLogger(csv_path=log_path)

    for update in range(args.num_updates):

        ep_info = {"reward": []}
        ep_reward = 0

        if args.lr_decay_type == "linear":
            update_linear_schedule(agent.optimizer, update, args.num_updates,
                                   args.lr, args.final_lr)
        elif args.lr_decay_type == "exponential":
            update_exponential_schedule(agent.optimizer, update, 0.99, args.lr,
                                        args.final_lr)

        for step in range(args.num_steps_per_rollout):
            # Sample actions
            with torch.no_grad():
                value, action, action_log_prob = actor_critic.act(
                    rollouts.observations[step])

            obs, reward, done, info = env.step(action)
            ep_reward += reward

            end_of_rollout = info.get("reset")
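            # masks: 0 where the episode terminated at this step, 1 otherwise;
            # bad_masks: 0 only where termination coincided with the rollout/time-limit reset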
            masks = (~done).float()
            bad_masks = (~(done * end_of_rollout)).float()

            if done.any():
                ep_info["reward"].append(ep_reward[done].clone())
                ep_reward *= (~done).float()  # zero out the dones
                reset_indices = env.parallel_ind_buf.masked_select(
                    done.squeeze())
                obs = env.reset(reset_indices)

            if end_of_rollout:
                obs = env.reset()

            rollouts.insert(obs, action, action_log_prob, value, reward, masks,
                            bad_masks)

        with torch.no_grad():
            next_value = actor_critic.get_value(
                rollouts.observations[-1]).detach()

        rollouts.compute_returns(next_value, use_gae, gamma, gae_lambda)

        value_loss, action_loss, dist_entropy = agent.update(rollouts)

        rollouts.after_update()

        torch.save(copy.deepcopy(actor_critic).cpu(), args.save_path)

        ep_info["reward"] = torch.cat(ep_info["reward"])
        logger.log_stats(
            args,
            {
                "update": update,
                "ep_info": ep_info,
                "dist_entropy": dist_entropy,
                "value_loss": value_loss,
                "action_loss": action_loss,
            },
        )
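
# The examples on this page all rely on RolloutStorage.compute_returns to produce GAE
# return targets; that helper is not shown in the snippets. Below is a minimal,
# self-contained sketch of the standard GAE(lambda) recursion such a helper typically
# implements. The tensor names (rewards, values, masks, next_value) and shapes are
# assumptions for illustration, not the storage class's real attributes:
#     delta_t = r_t + gamma * V(s_{t+1}) * mask_t - V(s_t)
#     A_t     = delta_t + gamma * lambda * mask_t * A_{t+1}
#     R_t     = A_t + V(s_t)
import torch


def compute_gae_returns_sketch(rewards, values, masks, next_value,
                               gamma=0.99, gae_lambda=0.95):
    # rewards, values, masks: (T, N, 1); next_value: (N, 1)
    # masks[t] is 0 if the episode ended at step t, 1 otherwise
    num_steps = rewards.size(0)
    values = torch.cat([values, next_value.unsqueeze(0)], dim=0)  # (T + 1, N, 1)
    returns = torch.zeros_like(rewards)
    gae = torch.zeros_like(next_value)
    for t in reversed(range(num_steps)):
        # one-step TD error, bootstrapping only if the episode did not end at step t
        delta = rewards[t] + gamma * values[t + 1] * masks[t] - values[t]
        gae = delta + gamma * gae_lambda * masks[t] * gae
        returns[t] = gae + values[t]
    return returns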
Example #2
def main():

    # --- ARGUMENTS ---
    parser = argparse.ArgumentParser()
    parser.add_argument('--env_name',
                        type=str,
                        default='coinrun',
                        help='name of the environment to train on.')
    parser.add_argument('--model',
                        type=str,
                        default='ppo',
                        help='the model to use for training.')
    args, rest_args = parser.parse_known_args()
    env_name = args.env_name
    model = args.model

    # get arguments
    args = args_pretrain_aup.get_args(rest_args)
    # place other args back into argparse.Namespace
    args.env_name = env_name
    args.model = model

    # Weights & Biases logger
    if args.run_name is None:
        # make run name as {env_name}_{TIME}
        now = datetime.datetime.now().strftime('_%d-%m_%H:%M:%S')
        args.run_name = args.env_name + '_' + args.algo + now
    # initialise wandb
    wandb.init(name=args.run_name,
               project=args.proj_name,
               group=args.group_name,
               config=args,
               monitor_gym=False)
    # save wandb dir path
    args.run_dir = wandb.run.dir
    wandb.config.update(args)

    # set random seed of random, torch and numpy
    utl.set_global_seed(args.seed, args.deterministic_execution)

    # --- OBTAIN DATA FOR TRAINING R_aux ---
    print("Gathering data for R_aux Model.")

    # gather observations for pretraining the auxiliary reward function (CB-VAE)
    envs = make_vec_envs(env_name=args.env_name,
                         start_level=0,
                         num_levels=0,
                         distribution_mode=args.distribution_mode,
                         paint_vel_info=args.paint_vel_info,
                         num_processes=args.num_processes,
                         num_frame_stack=args.num_frame_stack,
                         device=device)

    # number of frames ÷ number of policy steps before update ÷ number of cpu processes
    num_batch = args.num_processes * args.policy_num_steps
    num_updates = int(args.num_frames_r_aux) // num_batch

    # create list to store env observations
    obs_data = torch.zeros(num_updates * args.policy_num_steps + 1,
                           args.num_processes, *envs.observation_space.shape)
    # reset environments
    obs = envs.reset()  # obs.shape = (n_env,C,H,W)
    obs_data[0].copy_(obs)
    obs = obs.to(device)

    for iter_idx in range(num_updates):
        # rollout policy to collect num_batch of experience and store in storage
        for step in range(args.policy_num_steps):
            # sample actions from random agent
            action = torch.randint(0, envs.action_space.n,
                                   (args.num_processes, 1))
            # observe rewards and next obs
            obs, reward, done, infos = envs.step(action)
            # store obs
            obs_data[1 + iter_idx * args.policy_num_steps + step].copy_(obs)
    # close envs
    envs.close()

    # --- TRAIN R_aux (CB-VAE) ---
    # define CB-VAE where the encoder will be used as the auxiliary reward function R_aux
    print("Training R_aux Model.")

    # create dataloader for observations gathered
    obs_data = obs_data.reshape(-1, *envs.observation_space.shape)
    sampler = BatchSampler(SubsetRandomSampler(range(obs_data.size(0))),
                           args.cb_vae_batch_size,
                           drop_last=False)

    # initialise CB-VAE
    cb_vae = CBVAE(obs_shape=envs.observation_space.shape,
                   latent_dim=args.cb_vae_latent_dim).to(device)
    # optimiser
    optimiser = torch.optim.Adam(cb_vae.parameters(),
                                 lr=args.cb_vae_learning_rate)
    # put CB-VAE into train mode
    cb_vae.train()

    measures = defaultdict(list)
    for epoch in range(args.cb_vae_epochs):
        print("Epoch: ", epoch)
        start_time = time.time()
        batch_loss = 0
        for indices in sampler:
            obs = obs_data[indices].to(device)
            # zero accumulated gradients
            cb_vae.zero_grad()
            # forward pass through CB-VAE
            recon_batch, mu, log_var = cb_vae(obs)
            # calculate loss
            loss = cb_vae_loss(recon_batch, obs, mu, log_var)
            # backpropagation: compute gradients
            loss.backward()
            # update the parameters of the CB-VAE
            optimiser.step()

            # save loss per mini-batch
            batch_loss += loss.item() * obs.size(0)
        # log losses per epoch
        wandb.log({
            'cb_vae/loss': batch_loss / obs_data.size(0),
            'cb_vae/time_taken': time.time() - start_time,
            'cb_vae/epoch': epoch
        })
    indices = np.random.randint(0, obs.size(0), args.cb_vae_num_samples**2)
    measures['true_images'].append(obs[indices].detach().cpu().numpy())
    measures['recon_images'].append(
        recon_batch[indices].detach().cpu().numpy())

    # plot ground truth images
    plt.rcParams.update({'font.size': 10})
    fig, axs = plt.subplots(args.cb_vae_num_samples,
                            args.cb_vae_num_samples,
                            figsize=(20, 20))
    for i, img in enumerate(measures['true_images'][0]):
        axs[i // args.cb_vae_num_samples][i % args.cb_vae_num_samples].imshow(
            img.transpose(1, 2, 0))
        axs[i // args.cb_vae_num_samples][i %
                                          args.cb_vae_num_samples].axis('off')
    wandb.log({"Ground Truth Images": wandb.Image(plt)})
    # plot reconstructed images
    fig, axs = plt.subplots(args.cb_vae_num_samples,
                            args.cb_vae_num_samples,
                            figsize=(20, 20))
    for i, img in enumerate(measures['recon_images'][0]):
        axs[i // args.cb_vae_num_samples][i % args.cb_vae_num_samples].imshow(
            img.transpose(1, 2, 0))
        axs[i // args.cb_vae_num_samples][i %
                                          args.cb_vae_num_samples].axis('off')
    wandb.log({"Reconstructed Images": wandb.Image(plt)})

    # --- TRAIN Q_aux ---
    # train a PPO agent whose value head is replaced with an action-value head, trained on R_aux instead of the environment reward R
    print("Training Q_aux Model.")

    # initialise environments for training Q_aux
    envs = make_vec_envs(env_name=args.env_name,
                         start_level=0,
                         num_levels=0,
                         distribution_mode=args.distribution_mode,
                         paint_vel_info=args.paint_vel_info,
                         num_processes=args.num_processes,
                         num_frame_stack=args.num_frame_stack,
                         device=device)

    # initialise policy network
    actor_critic = QModel(obs_shape=envs.observation_space.shape,
                          action_space=envs.action_space,
                          hidden_size=args.hidden_size).to(device)

    # initialise policy trainer
    if args.algo == 'ppo':
        policy = PPO(actor_critic=actor_critic,
                     ppo_epoch=args.policy_ppo_epoch,
                     num_mini_batch=args.policy_num_mini_batch,
                     clip_param=args.policy_clip_param,
                     value_loss_coef=args.policy_value_loss_coef,
                     entropy_coef=args.policy_entropy_coef,
                     max_grad_norm=args.policy_max_grad_norm,
                     lr=args.policy_lr,
                     eps=args.policy_eps)
    else:
        raise NotImplementedError

    # initialise rollout storage for the policy
    rollouts = RolloutStorage(num_steps=args.policy_num_steps,
                              num_processes=args.num_processes,
                              obs_shape=envs.observation_space.shape,
                              action_space=envs.action_space)

    # count number of frames and updates
    frames = 0
    iter_idx = 0

    update_start_time = time.time()
    # reset environments
    obs = envs.reset()  # obs.shape = (n_envs,C,H,W)
    # insert initial observation to rollout storage
    rollouts.obs[0].copy_(obs)
    rollouts.to(device)

    # initialise buffer for calculating mean episodic returns
    episode_info_buf = deque(maxlen=10)

    # calculate number of updates
    # number of frames ÷ number of policy steps before update ÷ number of cpu processes
    args.num_batch = args.num_processes * args.policy_num_steps
    args.num_updates = int(args.num_frames_q_aux) // args.num_batch
    print("Number of updates: ", args.num_updates)
    for iter_idx in range(args.num_updates):
        print("Iter: ", iter_idx)

        # put actor-critic into train mode
        actor_critic.train()

        # rollout policy to collect num_batch of experience and store in storage
        for step in range(args.policy_num_steps):

            with torch.no_grad():
                # sample actions from policy
                value, action, action_log_prob = actor_critic.act(
                    rollouts.obs[step])
                # obtain reward R_aux from encoder of CB-VAE
                r_aux, _, _ = cb_vae.encode(rollouts.obs[step])

            # observe rewards and next obs
            obs, _, done, infos = envs.step(action)

            # log episode info if episode finished
            for i, info in enumerate(infos):
                if 'episode' in info.keys():
                    episode_info_buf.append(info['episode'])
            # create mask for episode ends
            masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                       for done_ in done]).to(device)

            # add experience to policy buffer
            rollouts.insert(obs, r_aux, action, value, action_log_prob, masks)

            frames += args.num_processes

        # --- UPDATE ---

        # bootstrap next value prediction
        with torch.no_grad():
            next_value = actor_critic.get_value(rollouts.obs[-1]).detach()

        # compute returns for current rollouts
        rollouts.compute_returns(next_value, args.policy_gamma,
                                 args.policy_gae_lambda)

        # update actor-critic using policy gradient algo
        total_loss, value_loss, action_loss, dist_entropy = policy.update(
            rollouts)

        # clean up after update
        rollouts.after_update()

        # --- LOGGING ---

        if iter_idx % args.log_interval == 0 or iter_idx == args.num_updates - 1:
            # get stats for run
            update_end_time = time.time()
            num_interval_updates = 1 if iter_idx == 0 else args.log_interval
            fps = num_interval_updates * (
                args.num_processes *
                args.policy_num_steps) / (update_end_time - update_start_time)
            update_start_time = update_end_time
            # Calculates whether the value function is a good predictor of the returns (ev close to 1)
            # or whether it's just worse than predicting nothing (ev <= 0)
            ev = utl_math.explained_variance(utl.sf01(rollouts.value_preds),
                                             utl.sf01(rollouts.returns))

            wandb.log({
                'q_aux_misc/timesteps':
                frames,
                'q_aux_misc/fps':
                fps,
                'q_aux_misc/explained_variance':
                float(ev),
                'q_aux_losses/total_loss':
                total_loss,
                'q_aux_losses/value_loss':
                value_loss,
                'q_aux_losses/action_loss':
                action_loss,
                'q_aux_losses/dist_entropy':
                dist_entropy,
                'q_aux_train/mean_episodic_return':
                utl_math.safe_mean(
                    [episode_info['r'] for episode_info in episode_info_buf]),
                'q_aux_train/mean_episodic_length':
                utl_math.safe_mean(
                    [episode_info['l'] for episode_info in episode_info_buf])
            })

    # close envs
    envs.close()

    # --- SAVE MODEL ---
    print("Saving Q_aux Model.")
    torch.save(actor_critic.state_dict(), args.q_aux_path)
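
# cb_vae_loss is called above but not defined in this snippet. A minimal sketch of a VAE
# objective with the same (recon_batch, obs, mu, log_var) signature is given below, assuming
# observations scaled to [0, 1] and a diagonal-Gaussian encoder. It uses a plain Bernoulli
# reconstruction term plus the Gaussian KL; a true continuous-Bernoulli VAE would add the
# CB log-normalising constant, which is omitted here for brevity.
import torch
import torch.nn.functional as F


def cb_vae_loss_sketch(recon_batch, obs, mu, log_var):
    # reconstruction term: binary cross-entropy summed over pixels, averaged over the batch
    recon_loss = F.binary_cross_entropy(recon_batch, obs, reduction='sum') / obs.size(0)
    # KL( q(z|x) || N(0, I) ) for a diagonal Gaussian posterior
    kl = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp()) / obs.size(0)
    return recon_loss + kl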
Example #3
def main(_seed, _config, _run):
    args = init(_seed, _config, _run)

    env_name = args.env_name

    dummy_env = make_env(env_name, render=False)

    cleanup_log_dir(args.log_dir)
    cleanup_log_dir(args.log_dir + "_test")

    try:
        os.makedirs(args.save_dir)
    except OSError:
        pass

    torch.set_num_threads(1)

    envs = make_vec_envs(env_name, args.seed, args.num_processes, args.log_dir)
    envs.set_mirror(args.use_phase_mirror)
    test_envs = make_vec_envs(env_name, args.seed, args.num_tests,
                              args.log_dir + "_test")
    test_envs.set_mirror(args.use_phase_mirror)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    if args.use_curriculum:
        curriculum = 0
        print("curriculum", curriculum)
        envs.update_curriculum(curriculum)
    if args.use_specialist:
        specialist = 0
        print("specialist", specialist)
        envs.update_specialist(specialist)
    if args.use_threshold_sampling:
        sampling_threshold = 200
        first_sampling = False
        uniform_sampling = True
        uniform_every = 500000
        uniform_counter = 1
        evaluate_envs = make_env(env_name, render=False)
        evaluate_envs.set_mirror(args.use_phase_mirror)
        evaluate_envs.update_curriculum(0)
        prob_filter = np.zeros((11, 11))
        prob_filter[5, 5] = 1
    if args.use_adaptive_sampling:
        evaluate_envs = make_env(env_name, render=False)
        evaluate_envs.set_mirror(args.use_phase_mirror)
        evaluate_envs.update_curriculum(0)
    if args.plot_prob:
        import matplotlib.pyplot as plt
        fig = plt.figure()
        plt.show(block=False)
        ax1 = fig.add_subplot(121)
        ax2 = fig.add_subplot(122)

    obs_shape = envs.observation_space.shape
    obs_shape = (obs_shape[0], *obs_shape[1:])

    if args.load_saved_controller:
        best_model = "{}_base.pt".format(env_name)
        model_path = os.path.join(current_dir, "models", best_model)
        print("Loading model {}".format(best_model))
        actor_critic = torch.load(model_path)
        actor_critic.reset_dist()
    else:
        controller = SoftsignActor(dummy_env)
        actor_critic = Policy(controller, num_ensembles=args.num_ensembles)

    mirror_function = None
    if args.use_mirror:
        indices = dummy_env.unwrapped.get_mirror_indices()
        mirror_function = get_mirror_function(indices)

    device = "cuda:0" if args.cuda else "cpu"
    if args.cuda:
        actor_critic.cuda()

    agent = PPO(actor_critic,
                mirror_function=mirror_function,
                **args.ppo_params)

    rollouts = RolloutStorage(
        args.num_steps,
        args.num_processes,
        obs_shape,
        envs.action_space.shape[0],
        actor_critic.state_size,
    )

    current_obs = torch.zeros(args.num_processes, *obs_shape)

    def update_current_obs(obs):
        shape_dim0 = envs.observation_space.shape[0]
        obs = torch.from_numpy(obs).float()
        current_obs[:, -shape_dim0:] = obs

    obs = envs.reset()
    update_current_obs(obs)

    rollouts.observations[0].copy_(current_obs)

    if args.cuda:
        current_obs = current_obs.cuda()
        rollouts.cuda()

    episode_rewards = deque(maxlen=args.num_processes)
    test_episode_rewards = deque(maxlen=args.num_tests)
    num_updates = int(args.num_frames) // args.num_steps // args.num_processes

    start = time.time()
    next_checkpoint = args.save_every
    max_ep_reward = float("-inf")

    logger = ConsoleCSVLogger(log_dir=args.experiment_dir,
                              console_log_interval=args.log_interval)

    update_values = False

    if args.save_sampling_prob:
        sampling_prob_list = []

    for j in range(num_updates):

        if args.lr_decay_type == "linear":
            scheduled_lr = linear_decay(j, num_updates, args.lr, final_value=0)
        elif args.lr_decay_type == "exponential":
            scheduled_lr = exponential_decay(j,
                                             0.99,
                                             args.lr,
                                             final_value=3e-5)
        else:
            scheduled_lr = args.lr

        set_optimizer_lr(agent.optimizer, scheduled_lr)

        ac_state_dict = copy.deepcopy(actor_critic).cpu().state_dict()

        if update_values and args.use_threshold_sampling:
            envs.update_curriculum(5)
        elif (not update_values
              ) and args.use_threshold_sampling and first_sampling:
            envs.update_specialist(0)

        if args.use_threshold_sampling and not uniform_sampling:
            obs = evaluate_envs.reset()
            yaw_size = dummy_env.yaw_samples.shape[0]
            pitch_size = dummy_env.pitch_samples.shape[0]
            total_metric = torch.zeros(1, yaw_size * pitch_size).to(device)
            evaluate_counter = 0
            while True:
                obs = torch.from_numpy(obs).float().unsqueeze(0).to(device)
                with torch.no_grad():
                    _, action, _, _ = actor_critic.act(obs,
                                                       None,
                                                       None,
                                                       deterministic=True)
                cpu_actions = action.squeeze().cpu().numpy()
                obs, reward, done, info = evaluate_envs.step(cpu_actions)
                if done:
                    obs = evaluate_envs.reset()
                if evaluate_envs.update_terrain:
                    evaluate_counter += 1
                    temp_states = evaluate_envs.create_temp_states()
                    with torch.no_grad():
                        temp_states = torch.from_numpy(temp_states).float().to(
                            device)
                        value_samples = actor_critic.get_ensemble_values(
                            temp_states, None, None)
                        #yaw_size = dummy_env.yaw_samples.shape[0]
                        mean = value_samples.mean(dim=-1)
                        #mean = value_samples.min(dim=-1)[0]
                        metric = mean.clone()
                        metric = metric.view(yaw_size, pitch_size)
                        #metric = metric / (metric.abs().max())
                        metric = metric.view(1, yaw_size * pitch_size)
                        total_metric += metric
                if evaluate_counter >= 5:
                    total_metric /= (total_metric.abs().max())
                    #total_metric[total_metric < 0.7] = 0
                    print("metric", total_metric)
                    sampling_probs = (
                        -10 * (total_metric - args.curriculum_threshold).abs()
                    ).softmax(dim=1).view(
                        yaw_size, pitch_size
                    )  #threshold1:150, 0.9 l2, threshold2: 10, 0.85 l1, threshold3: 10, 0.85, l1, 0.40 gap
                    #threshold 4: 20, 0.85, l1, yaw 10
                    if args.save_sampling_prob:
                        sampling_prob_list.append(sampling_probs.cpu().numpy())
                    sample_probs = np.zeros(
                        (args.num_processes, yaw_size, pitch_size))
                    #print("prob", sampling_probs)
                    for i in range(args.num_processes):
                        sample_probs[i, :, :] = np.copy(
                            sampling_probs.cpu().numpy().astype(np.float64))
                    envs.update_sample_prob(sample_probs)
                    break
        elif args.use_threshold_sampling and uniform_sampling:
            envs.update_curriculum(5)
        # if args.use_threshold_sampling and not uniform_sampling:
        #     obs = evaluate_envs.reset()
        #     yaw_size = dummy_env.yaw_samples.shape[0]
        #     pitch_size = dummy_env.pitch_samples.shape[0]
        #     r_size = dummy_env.r_samples.shape[0]
        #     total_metric = torch.zeros(1, yaw_size * pitch_size * r_size).to(device)
        #     evaluate_counter = 0
        #     while True:
        #         obs = torch.from_numpy(obs).float().unsqueeze(0).to(device)
        #         with torch.no_grad():
        #             _, action, _, _ = actor_critic.act(
        #             obs, None, None, deterministic=True
        #             )
        #         cpu_actions = action.squeeze().cpu().numpy()
        #         obs, reward, done, info = evaluate_envs.step(cpu_actions)
        #         if done:
        #             obs = evaluate_envs.reset()
        #         if evaluate_envs.update_terrain:
        #             evaluate_counter += 1
        #             temp_states = evaluate_envs.create_temp_states()
        #             with torch.no_grad():
        #                 temp_states = torch.from_numpy(temp_states).float().to(device)
        #                 value_samples = actor_critic.get_ensemble_values(temp_states, None, None)
        #                 mean = value_samples.mean(dim=-1)
        #                 #mean = value_samples.min(dim=-1)[0]
        #                 metric = mean.clone()
        #                 metric = metric.view(yaw_size, pitch_size, r_size)
        #                 #metric = metric / (metric.abs().max())
        #                 metric = metric.view(1, yaw_size*pitch_size*r_size)
        #                 total_metric += metric
        #         if evaluate_counter >= 5:
        #             total_metric /= (total_metric.abs().max())
        #             #total_metric[total_metric < 0.7] = 0
        #             #print("metric", total_metric)
        #             sampling_probs = (-10*(total_metric-0.85).abs()).softmax(dim=1).view(yaw_size, pitch_size, r_size) #threshold1:150, 0.9 l2, threshold2: 10, 0.85 l1, threshold3: 10, 0.85, l1, 0.40 gap
        #             #threshold 4: 3d grid, 10, 0.85, l1
        #             sample_probs = np.zeros((args.num_processes, yaw_size, pitch_size, r_size))
        #             #print("prob", sampling_probs)
        #             for i in range(args.num_processes):
        #                 sample_probs[i, :, :, :] = np.copy(sampling_probs.cpu().numpy().astype(np.float64))
        #             envs.update_sample_prob(sample_probs)
        #             break
        # elif args.use_threshold_sampling and uniform_sampling:
        #     envs.update_curriculum(5)

        if args.use_adaptive_sampling:
            obs = evaluate_envs.reset()
            yaw_size = dummy_env.yaw_samples.shape[0]
            pitch_size = dummy_env.pitch_samples.shape[0]
            total_metric = torch.zeros(1, yaw_size * pitch_size).to(device)
            evaluate_counter = 0
            while True:
                obs = torch.from_numpy(obs).float().unsqueeze(0).to(device)
                with torch.no_grad():
                    _, action, _, _ = actor_critic.act(obs,
                                                       None,
                                                       None,
                                                       deterministic=True)
                cpu_actions = action.squeeze().cpu().numpy()
                obs, reward, done, info = evaluate_envs.step(cpu_actions)
                if done:
                    obs = evaluate_envs.reset()
                if evaluate_envs.update_terrain:
                    evaluate_counter += 1
                    temp_states = evaluate_envs.create_temp_states()
                    with torch.no_grad():
                        temp_states = torch.from_numpy(temp_states).float().to(
                            device)
                        value_samples = actor_critic.get_ensemble_values(
                            temp_states, None, None)
                        mean = value_samples.mean(dim=-1)
                        metric = mean.clone()
                        metric = metric.view(yaw_size, pitch_size)
                        #metric = metric / metric.abs().max()
                        metric = metric.view(1, yaw_size * pitch_size)
                        total_metric += metric
                        # sampling_probs = (-30*metric).softmax(dim=1).view(size, size)
                        # sample_probs = np.zeros((args.num_processes, size, size))
                        # for i in range(args.num_processes):
                        #     sample_probs[i, :, :] = np.copy(sampling_probs.cpu().numpy().astype(np.float64))
                        # envs.update_sample_prob(sample_probs)
                if evaluate_counter >= 5:
                    total_metric /= (total_metric.abs().max())
                    print("metric", total_metric)
                    sampling_probs = (-10 * total_metric).softmax(dim=1).view(
                        yaw_size, pitch_size)
                    sample_probs = np.zeros(
                        (args.num_processes, yaw_size, pitch_size))
                    for i in range(args.num_processes):
                        sample_probs[i, :, :] = np.copy(
                            sampling_probs.cpu().numpy().astype(np.float64))
                    envs.update_sample_prob(sample_probs)
                    break

        for step in range(args.num_steps):
            # Sample actions
            with torch.no_grad():
                value, action, action_log_prob, states = actor_critic.act(
                    rollouts.observations[step],
                    rollouts.states[step],
                    rollouts.masks[step],
                    deterministic=update_values)
            cpu_actions = action.squeeze(1).cpu().numpy()

            obs, reward, done, infos = envs.step(cpu_actions)
            reward = torch.from_numpy(np.expand_dims(np.stack(reward),
                                                     1)).float()

            if args.plot_prob and step == 0:
                temp_states = envs.create_temp_states()
                with torch.no_grad():
                    temp_states = torch.from_numpy(temp_states).float().to(
                        device)
                    value_samples = actor_critic.get_value(
                        temp_states, None, None)
                size = dummy_env.yaw_samples.shape[0]
                v = value_samples.mean(dim=0).view(size, size).cpu().numpy()
                vs = value_samples.var(dim=0).view(size, size).cpu().numpy()
                ax1.pcolormesh(v)
                ax2.pcolormesh(vs)
                print(np.round(v, 2))
                fig.canvas.draw()

            # if args.use_adaptive_sampling:
            #     temp_states = envs.create_temp_states()
            #     with torch.no_grad():
            #         temp_states = torch.from_numpy(temp_states).float().to(device)
            #         value_samples = actor_critic.get_value(temp_states, None, None)

            #     size = dummy_env.yaw_samples.shape[0]
            #     sample_probs = (-value_samples / 5).softmax(dim=1).view(args.num_processes, size, size)
            #     envs.update_sample_prob(sample_probs.cpu().numpy())

            # if args.use_threshold_sampling and not uniform_sampling:
            #     temp_states = envs.create_temp_states()
            #     with torch.no_grad():
            #         temp_states = torch.from_numpy(temp_states).float().to(device)
            #         value_samples = actor_critic.get_ensemble_values(temp_states, None, None)
            #     size = dummy_env.yaw_samples.shape[0]
            #     mean = value_samples.mean(dim=-1)
            #     std = value_samples.std(dim=-1)

            #using std
            # metric = std.clone()
            # metric = metric.view(args.num_processes, size, size)
            # value_filter = torch.ones(args.num_processes, 11, 11).to(device) * -1e5
            # value_filter[:, 5 - curriculum: 5 + curriculum + 1, 5 - curriculum: 5 + curriculum + 1] = 0
            # metric = metric / metric.max() + value_filter
            # metric = metric.view(args.num_processes, size*size)
            # sample_probs = (30*metric).softmax(dim=1).view(args.num_processes, size, size)

            #using value estimate
            # metric = mean.clone()
            # metric = metric.view(args.num_processes, size, size)
            # value_filter = torch.ones(args.num_processes, 11, 11).to(device) * -1e5
            # value_filter[:, 5 - curriculum: 5 + curriculum + 1, 5 - curriculum: 5 + curriculum + 1] = 0
            # metric = metric / metric.abs().max() - value_filter
            # metric = metric.view(args.num_processes, size*size)
            # sample_probs = (-30*metric).softmax(dim=1).view(args.num_processes, size, size)

            # if args.plot_prob and step == 0:
            #     #print(sample_probs.cpu().numpy()[0, :, :])
            #     ax.pcolormesh(sample_probs.cpu().numpy()[0, :, :])
            #     print(np.round(sample_probs.cpu().numpy()[0, :, :], 4))
            #     fig.canvas.draw()
            # envs.update_sample_prob(sample_probs.cpu().numpy())

            #using value threshold
            # metric = mean.clone()
            # metric = metric.view(args.num_processes, size, size)
            # metric = metric / metric.abs().max()# - value_filter
            # metric = metric.view(args.num_processes, size*size)
            # sample_probs = (-30*(metric-0.8)**2).softmax(dim=1).view(args.num_processes, size, size)

            # if args.plot_prob and step == 0:
            #     ax.pcolormesh(sample_probs.cpu().numpy()[0, :, :])
            #     print(np.round(sample_probs.cpu().numpy()[0, :, :], 4))
            #     fig.canvas.draw()
            # envs.update_sample_prob(sample_probs.cpu().numpy())

            bad_masks = np.ones((args.num_processes, 1))
            for p_index, info in enumerate(infos):
                keys = info.keys()
                # This information is added by algorithms.utils.TimeLimitMask
                if "bad_transition" in keys:
                    bad_masks[p_index] = 0.0
                # This information is added by baselines.bench.Monitor
                if "episode" in keys:
                    episode_rewards.append(info["episode"]["r"])

            masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                       for done_ in done])
            bad_masks = torch.from_numpy(bad_masks)

            update_current_obs(obs)
            rollouts.insert(
                current_obs,
                states,
                action,
                action_log_prob,
                value,
                reward,
                masks,
                bad_masks,
            )

        obs = test_envs.reset()
        if args.use_threshold_sampling:
            if uniform_counter % uniform_every == 0:
                uniform_sampling = True
                uniform_counter = 0
            else:
                uniform_sampling = False
            uniform_counter += 1
            if uniform_sampling:
                envs.update_curriculum(5)
                print("uniform")

        #print("max_step", dummy_env._max_episode_steps)
        for step in range(dummy_env._max_episode_steps):
            # Sample actions
            with torch.no_grad():
                obs = torch.from_numpy(obs).float().to(device)
                _, action, _, _ = actor_critic.act(obs,
                                                   None,
                                                   None,
                                                   deterministic=True)
            cpu_actions = action.squeeze(1).cpu().numpy()

            obs, reward, done, infos = test_envs.step(cpu_actions)
            reward = torch.from_numpy(np.expand_dims(np.stack(reward),
                                                     1)).float()

            for p_index, info in enumerate(infos):
                keys = info.keys()
                # This information is added by baselines.bench.Monitor
                if "episode" in keys:
                    #print(info["episode"]["r"])
                    test_episode_rewards.append(info["episode"]["r"])

        if args.use_curriculum and np.mean(
                episode_rewards) > 1000 and curriculum <= 4:
            curriculum += 1
            print("curriculum", curriculum)
            envs.update_curriculum(curriculum)

        with torch.no_grad():
            next_value = actor_critic.get_value(rollouts.observations[-1],
                                                rollouts.states[-1],
                                                rollouts.masks[-1]).detach()

        rollouts.compute_returns(next_value, args.use_gae, args.gamma,
                                 args.gae_lambda)

        if update_values:
            value_loss = agent.update_values(rollouts)
        else:
            value_loss, action_loss, dist_entropy = agent.update(rollouts)
        #update_values = (not update_values)

        rollouts.after_update()

        frame_count = (j + 1) * args.num_steps * args.num_processes
        if (frame_count >= next_checkpoint
                or j == num_updates - 1) and args.save_dir != "":
            model_name = "{}_{:d}.pt".format(env_name, int(next_checkpoint))
            next_checkpoint += args.save_every
        else:
            model_name = "{}_latest.pt".format(env_name)

        if args.save_sampling_prob:
            import pickle
            with open('{}_sampling_prob85.pkl'.format(env_name), 'wb') as fp:
                pickle.dump(sampling_prob_list, fp)

        # A really ugly way to save a model to CPU
        save_model = actor_critic
        if args.cuda:
            save_model = copy.deepcopy(actor_critic).cpu()

        if args.use_specialist and np.mean(
                episode_rewards) > 1000 and specialist <= 4:
            specialist_name = "{}_specialist_{:d}.pt".format(
                env_name, int(specialist))
            specialist_model = actor_critic
            if args.cuda:
                specialist_model = copy.deepcopy(actor_critic).cpu()
            torch.save(specialist_model,
                       os.path.join(args.save_dir, specialist_name))
            specialist += 1
            envs.update_specialist(specialist)
        # if args.use_threshold_sampling and np.mean(episode_rewards) > 1000 and curriculum <= 4:
        #     first_sampling = False
        #     curriculum += 1
        #     print("curriculum", curriculum)
        #     envs.update_curriculum(curriculum)
        #     prob_filter[5-curriculum:5+curriculum+1, 5-curriculum:5+curriculum+1] = 1

        torch.save(save_model, os.path.join(args.save_dir, model_name))

        if len(episode_rewards) > 1 and np.mean(
                episode_rewards) > max_ep_reward:
            model_name = "{}_best.pt".format(env_name)
            max_ep_reward = np.mean(episode_rewards)
            torch.save(save_model, os.path.join(args.save_dir, model_name))

        if len(episode_rewards) > 1:
            end = time.time()
            total_num_steps = (j + 1) * args.num_processes * args.num_steps
            logger.log_epoch({
                "iter": j + 1,
                "total_num_steps": total_num_steps,
                "fps": int(total_num_steps / (end - start)),
                "entropy": dist_entropy,
                "value_loss": value_loss,
                "action_loss": action_loss,
                "stats": {
                    "rew": episode_rewards
                },
                "test_stats": {
                    "rew": test_episode_rewards
                },
            })
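
# Example #3 turns ensemble value estimates over a yaw x pitch grid of terrain samples into
# terrain-sampling probabilities concentrated near a target value threshold. A minimal,
# self-contained sketch of that mapping follows; the temperature (10) and the default
# threshold mirror values appearing above, and `values` stands in for the normalised,
# averaged ensemble metric (total_metric in the example).
import torch


def threshold_sampling_probs_sketch(values, threshold=0.85, temperature=10.0):
    # values: (yaw_size, pitch_size) normalised value estimates
    yaw_size, pitch_size = values.shape
    flat = values.view(1, yaw_size * pitch_size)
    # put high probability on grid cells whose value estimate is close to the threshold,
    # i.e. terrains the policy can neither solve reliably nor fail on completely
    probs = (-temperature * (flat - threshold).abs()).softmax(dim=1)
    return probs.view(yaw_size, pitch_size)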
Example #4
def main(_seed, _config, _run):
    args = init(_seed, _config, _run, post_config=post_config)

    env_name = args.env_name

    dummy_env = make_env(env_name, render=False)

    cleanup_log_dir(args.log_dir)

    try:
        os.makedirs(args.save_dir)
    except OSError:
        pass

    torch.set_num_threads(1)

    envs = make_vec_envs(env_name, args.seed, args.num_processes, args.log_dir)

    obs_shape = envs.observation_space.shape
    obs_shape = (obs_shape[0], *obs_shape[1:])

    if args.load_saved_controller:
        best_model = "{}_best.pt".format(env_name)
        model_path = os.path.join(current_dir, "models", best_model)
        print("Loading model {}".format(best_model))
        actor_critic = torch.load(model_path)
    else:
        if args.mirror_method == MirrorMethods.net2:
            controller = SymmetricNetV2(
                *dummy_env.unwrapped.mirror_sizes,
                num_layers=6,
                hidden_size=256,
                tanh_finish=True
            )
        else:
            controller = SoftsignActor(dummy_env)
            if args.mirror_method == MirrorMethods.net:
                controller = SymmetricNet(controller, *dummy_env.unwrapped.sym_act_inds)
        actor_critic = Policy(controller)
        if args.sym_value_net:
            actor_critic.critic = SymmetricVNet(
                actor_critic.critic, controller.state_dim
            )

    mirror_function = None
    if (
        args.mirror_method == MirrorMethods.traj
        or args.mirror_method == MirrorMethods.loss
    ):
        indices = dummy_env.unwrapped.get_mirror_indices()
        mirror_function = get_mirror_function(indices)

    if args.cuda:
        actor_critic.cuda()

    agent = PPO(actor_critic, mirror_function=mirror_function, **args.ppo_params)

    rollouts = RolloutStorage(
        args.num_steps,
        args.num_processes,
        obs_shape,
        envs.action_space.shape[0],
        actor_critic.state_size,
    )
    current_obs = torch.zeros(args.num_processes, *obs_shape)

    def update_current_obs(obs):
        shape_dim0 = envs.observation_space.shape[0]
        obs = torch.from_numpy(obs).float()
        current_obs[:, -shape_dim0:] = obs

    obs = envs.reset()
    update_current_obs(obs)

    rollouts.observations[0].copy_(current_obs)

    if args.cuda:
        current_obs = current_obs.cuda()
        rollouts.cuda()

    episode_rewards = deque(maxlen=args.num_processes)
    num_updates = int(args.num_frames) // args.num_steps // args.num_processes

    start = time.time()
    next_checkpoint = args.save_every
    max_ep_reward = float("-inf")

    logger = ConsoleCSVLogger(
        log_dir=args.experiment_dir, console_log_interval=args.log_interval
    )

    for j in range(num_updates):

        if args.lr_decay_type == "linear":
            scheduled_lr = linear_decay(j, num_updates, args.lr, final_value=0)
        elif args.lr_decay_type == "exponential":
            scheduled_lr = exponential_decay(j, 0.99, args.lr, final_value=3e-5)
        else:
            scheduled_lr = args.lr

        set_optimizer_lr(agent.optimizer, scheduled_lr)

        for step in range(args.num_steps):
            # Sample actions
            with torch.no_grad():
                value, action, action_log_prob, states = actor_critic.act(
                    rollouts.observations[step],
                    rollouts.states[step],
                    rollouts.masks[step],
                )
            cpu_actions = action.squeeze(1).cpu().numpy()

            obs, reward, done, infos = envs.step(cpu_actions)
            reward = torch.from_numpy(np.expand_dims(np.stack(reward), 1)).float()

            bad_masks = np.ones((args.num_processes, 1))
            for p_index, info in enumerate(infos):
                keys = info.keys()
                # This information is added by algorithms.utils.TimeLimitMask
                if "bad_transition" in keys:
                    bad_masks[p_index] = 0.0
                # This information is added by baselines.bench.Monitor
                if "episode" in keys:
                    episode_rewards.append(info["episode"]["r"])

            masks = torch.FloatTensor([[0.0] if done_ else [1.0] for done_ in done])
            bad_masks = torch.from_numpy(bad_masks)

            update_current_obs(obs)
            rollouts.insert(
                current_obs,
                states,
                action,
                action_log_prob,
                value,
                reward,
                masks,
                bad_masks,
            )

        with torch.no_grad():
            next_value = actor_critic.get_value(
                rollouts.observations[-1], rollouts.states[-1], rollouts.masks[-1]
            ).detach()

        rollouts.compute_returns(next_value, args.use_gae, args.gamma, args.gae_lambda)

        value_loss, action_loss, dist_entropy = agent.update(rollouts)

        rollouts.after_update()

        frame_count = (j + 1) * args.num_steps * args.num_processes
        if (
            frame_count >= next_checkpoint or j == num_updates - 1
        ) and args.save_dir != "":
            model_name = "{}_{:d}.pt".format(env_name, int(next_checkpoint))
            next_checkpoint += args.save_every
        else:
            model_name = "{}_latest.pt".format(env_name)

        # A really ugly way to save a model to CPU
        save_model = actor_critic
        if args.cuda:
            save_model = copy.deepcopy(actor_critic).cpu()
        # optionally also save a copy to Google Drive (when running on Colab)
        drive = True
        if drive:
            # print("save")
            torch.save(save_model, os.path.join("/content/gdrive/My Drive/darwin", model_name))
        torch.save(save_model, os.path.join(args.save_dir, model_name))

        if len(episode_rewards) > 1 and np.mean(episode_rewards) > max_ep_reward:
            model_name = "{}_best.pt".format(env_name)
            max_ep_reward = np.mean(episode_rewards)
            drive = True
            if drive:
                # print("max_ep_reward", max_ep_reward)
                torch.save(save_model, os.path.join("/content/gdrive/My Drive/darwin", model_name))
            torch.save(save_model, os.path.join(args.save_dir, model_name))

        if len(episode_rewards) > 1:
            end = time.time()
            total_num_steps = (j + 1) * args.num_processes * args.num_steps
            logger.log_epoch(
                {
                    "iter": j + 1,
                    "total_num_steps": total_num_steps,
                    "fps": int(total_num_steps / (end - start)),
                    "entropy": dist_entropy,
                    "value_loss": value_loss,
                    "action_loss": action_loss,
                    "stats": {"rew": episode_rewards},
                }
            )
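
# linear_decay, exponential_decay and set_optimizer_lr are used in Examples #3 and #4 (and
# Example #1's update_linear_schedule / update_exponential_schedule presumably wrap the same
# logic), but none of them are defined in these snippets. Minimal sketches consistent with
# how they are called follow; the exact behaviour in the original repositories may differ.
def linear_decay(iteration, num_iterations, initial_value, final_value=0.0):
    # interpolate linearly from initial_value down to final_value over num_iterations
    frac = min(iteration / max(num_iterations, 1), 1.0)
    return initial_value + frac * (final_value - initial_value)


def exponential_decay(iteration, decay_rate, initial_value, final_value=0.0):
    # multiply by decay_rate every iteration, flooring at final_value
    return max(initial_value * (decay_rate ** iteration), final_value)


def set_optimizer_lr(optimizer, lr):
    # write the scheduled learning rate into every parameter group of a torch optimizer
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr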
Example #5
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--env_name',
                        type=str,
                        default='coinrun',
                        help='name of the environment to train on.')
    parser.add_argument('--model',
                        type=str,
                        default='ppo',
                        help='the model to use for training. {ppo, ppo_aup}')
    args, rest_args = parser.parse_known_args()
    env_name = args.env_name
    model = args.model

    # --- ARGUMENTS ---

    if model == 'ppo':
        args = args_ppo.get_args(rest_args)
    elif model == 'ppo_aup':
        args = args_ppo_aup.get_args(rest_args)
    else:
        raise NotImplementedError

    # place other args back into argparse.Namespace
    args.env_name = env_name
    args.model = model

    # warnings
    if args.deterministic_execution:
        print('Invoking deterministic code execution.')
        if torch.backends.cudnn.enabled:
            warnings.warn('Running with deterministic CUDNN.')
        if args.num_processes > 1:
            raise RuntimeError(
                'If you want fully deterministic code, run it with num_processes=1. '
                'Warning: This will slow things down and might break A2C if '
                'policy_num_steps < env._max_episode_steps.')

    # --- TRAINING ---
    print("Setting up wandb logging.")

    # Weights & Biases logger
    if args.run_name is None:
        # make run name as {env_name}_{TIME}
        now = datetime.datetime.now().strftime('_%d-%m_%H:%M:%S')
        args.run_name = args.env_name + '_' + args.algo + now
    # initialise wandb
    wandb.init(project=args.proj_name,
               name=args.run_name,
               group=args.group_name,
               config=args,
               monitor_gym=False)
    # save wandb dir path
    args.run_dir = wandb.run.dir
    # make directory for saving models
    save_dir = os.path.join(wandb.run.dir, 'models')
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    # set random seed of random, torch and numpy
    utl.set_global_seed(args.seed, args.deterministic_execution)

    print("Setting up Environments.")
    # initialise environments for training
    train_envs = make_vec_envs(env_name=args.env_name,
                               start_level=args.train_start_level,
                               num_levels=args.train_num_levels,
                               distribution_mode=args.distribution_mode,
                               paint_vel_info=args.paint_vel_info,
                               num_processes=args.num_processes,
                               num_frame_stack=args.num_frame_stack,
                               device=device)
    # initialise environments for evaluation
    eval_envs = make_vec_envs(env_name=args.env_name,
                              start_level=0,
                              num_levels=0,
                              distribution_mode=args.distribution_mode,
                              paint_vel_info=args.paint_vel_info,
                              num_processes=args.num_processes,
                              num_frame_stack=args.num_frame_stack,
                              device=device)
    _ = eval_envs.reset()

    print("Setting up Actor-Critic model and Training algorithm.")
    # initialise policy network
    actor_critic = ACModel(obs_shape=train_envs.observation_space.shape,
                           action_space=train_envs.action_space,
                           hidden_size=args.hidden_size).to(device)

    # initialise policy training algorithm
    if args.algo == 'ppo':
        policy = PPO(actor_critic=actor_critic,
                     ppo_epoch=args.policy_ppo_epoch,
                     num_mini_batch=args.policy_num_mini_batch,
                     clip_param=args.policy_clip_param,
                     value_loss_coef=args.policy_value_loss_coef,
                     entropy_coef=args.policy_entropy_coef,
                     max_grad_norm=args.policy_max_grad_norm,
                     lr=args.policy_lr,
                     eps=args.policy_eps)
    else:
        raise NotImplementedError

    # initialise rollout storage for the policy training algorithm
    rollouts = RolloutStorage(num_steps=args.policy_num_steps,
                              num_processes=args.num_processes,
                              obs_shape=train_envs.observation_space.shape,
                              action_space=train_envs.action_space)

    # initialise Q_aux function(s) for AUP
    if args.use_aup:
        print("Initialising Q_aux models.")
        q_aux = [
            QModel(obs_shape=train_envs.observation_space.shape,
                   action_space=train_envs.action_space,
                   hidden_size=args.hidden_size).to(device)
            for _ in range(args.num_q_aux)
        ]
        if args.num_q_aux == 1:
            # load weights to model
            path = args.q_aux_dir + "0.pt"
            q_aux[0].load_state_dict(torch.load(path))
            q_aux[0].eval()
        else:
            # get the number of available q_aux checkpoints to choose from
            args.max_num_q_aux = len(os.listdir(args.q_aux_dir))
            q_aux_models = random.sample(list(range(0, args.max_num_q_aux)),
                                         args.num_q_aux)
            # load weights to models
            for i, model in enumerate(q_aux):
                path = args.q_aux_dir + str(q_aux_models[i]) + ".pt"
                model.load_state_dict(torch.load(path))
                model.eval()

    # count number of frames and updates
    frames = 0
    iter_idx = 0

    # update wandb args
    wandb.config.update(args)

    update_start_time = time.time()
    # reset environments
    obs = train_envs.reset()  # obs.shape = (num_processes,C,H,W)
    # insert initial observation to rollout storage
    rollouts.obs[0].copy_(obs)
    rollouts.to(device)

    # initialise buffer for calculating mean episodic returns
    episode_info_buf = deque(maxlen=10)

    # calculate number of updates
    # number of frames ÷ number of policy steps before update ÷ number of processes
    args.num_batch = args.num_processes * args.policy_num_steps
    args.num_updates = int(args.num_frames) // args.num_batch

    # define AUP coefficient
    if args.use_aup:
        aup_coef = args.aup_coef_start
        aup_linear_increase_val = math.exp(
            math.log(args.aup_coef_end / args.aup_coef_start) /
            args.num_updates)

    print("Training beginning.")
    print("Number of updates: ", args.num_updates)
    for iter_idx in range(args.num_updates):
        print("Iter: ", iter_idx)

        # put actor-critic into train mode
        actor_critic.train()

        if args.use_aup:
            aup_measures = defaultdict(list)

        # rollout policy to collect num_batch of experience and place in storage
        for step in range(args.policy_num_steps):

            # sample actions from policy
            with torch.no_grad():
                value, action, action_log_prob = actor_critic.act(
                    rollouts.obs[step])

            # observe rewards and next obs
            obs, reward, done, infos = train_envs.step(action)

            # calculate AUP reward
            if args.use_aup:
                intrinsic_reward = torch.zeros_like(reward)
                with torch.no_grad():
                    for model in q_aux:
                        # get action-values
                        action_values = model.get_action_value(
                            rollouts.obs[step])
                        # get action-value for action taken
                        action_value = torch.sum(
                            action_values * torch.nn.functional.one_hot(
                                action,
                                num_classes=train_envs.action_space.n).squeeze(
                                    dim=1),
                            dim=1)
                        # calculate the penalty
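                        # AUP penalty: |Q_aux(s, a) - Q_aux(s, no-op)|; index 4 is
                        # presumably the no-op action in procgen's 15-action space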
                        intrinsic_reward += torch.abs(
                            action_value.unsqueeze(dim=1) -
                            action_values[:, 4].unsqueeze(dim=1))
                intrinsic_reward /= args.num_q_aux
                # subtract the scaled AUP penalty from the extrinsic reward
                reward -= aup_coef * intrinsic_reward
                # log the intrinsic reward from the first env.
                aup_measures['intrinsic_reward'].append(aup_coef *
                                                        intrinsic_reward[0, 0])
                if done[0] and infos[0]['prev_level_complete'] == 1:
                    aup_measures['episode_complete'].append(2)
                elif done[0] and infos[0]['prev_level_complete'] == 0:
                    aup_measures['episode_complete'].append(1)
                else:
                    aup_measures['episode_complete'].append(0)

            # log episode info if episode finished
            for info in infos:
                if 'episode' in info.keys():
                    episode_info_buf.append(info['episode'])

            # create mask for episode ends
            masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                       for done_ in done]).to(device)

            # add experience to storage
            rollouts.insert(obs, reward, action, value, action_log_prob, masks)

            frames += args.num_processes

        # scale the AUP coefficient after every update (geometric schedule from aup_coef_start to aup_coef_end)
        if args.use_aup:
            aup_coef *= aup_linear_increase_val

        # --- UPDATE ---

        # bootstrap next value prediction
        with torch.no_grad():
            next_value = actor_critic.get_value(rollouts.obs[-1]).detach()

        # compute returns for current rollouts
        rollouts.compute_returns(next_value, args.policy_gamma,
                                 args.policy_gae_lambda)
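        # (assuming the usual GAE recursion inside compute_returns:
        #  delta_t = r_t + gamma * V(s_{t+1}) * mask_{t+1} - V(s_t),
        #  A_t = delta_t + gamma * lambda * mask_{t+1} * A_{t+1},
        #  returns_t = A_t + V(s_t))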

        # update actor-critic using policy training algorithm
        total_loss, value_loss, action_loss, dist_entropy = policy.update(
            rollouts)

        # clean up storage after update
        rollouts.after_update()

        # --- LOGGING ---

        if iter_idx % args.log_interval == 0 or iter_idx == args.num_updates - 1:

            # --- EVALUATION ---
            eval_episode_info_buf = utl_eval.evaluate(
                eval_envs=eval_envs, actor_critic=actor_critic, device=device)

            # get stats for run
            update_end_time = time.time()
            num_interval_updates = 1 if iter_idx == 0 else args.log_interval
            fps = num_interval_updates * (
                args.num_processes *
                args.policy_num_steps) / (update_end_time - update_start_time)
            update_start_time = update_end_time
            # calculates whether the value function is a good predictor of the
            # returns (ev close to 1) or worse than predicting nothing (ev <= 0)
            ev = utl_math.explained_variance(utl.sf01(rollouts.value_preds),
                                             utl.sf01(rollouts.returns))
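            # (explained variance here is assumed to follow the usual
            #  definition: ev = 1 - Var[returns - value_preds] / Var[returns])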

            if args.use_aup:
                step = frames - args.num_processes * args.policy_num_steps
                for i in range(args.policy_num_steps):
                    wandb.log(
                        {
                            'aup/intrinsic_reward':
                            aup_measures['intrinsic_reward'][i],
                            'aup/episode_complete':
                            aup_measures['episode_complete'][i]
                        },
                        step=step)
                    step += args.num_processes

            wandb.log(
                {
                    'misc/timesteps':
                    frames,
                    'misc/fps':
                    fps,
                    'misc/explained_variance':
                    float(ev),
                    'losses/total_loss':
                    total_loss,
                    'losses/value_loss':
                    value_loss,
                    'losses/action_loss':
                    action_loss,
                    'losses/dist_entropy':
                    dist_entropy,
                    'train/mean_episodic_return':
                    utl_math.safe_mean([
                        episode_info['r'] for episode_info in episode_info_buf
                    ]),
                    'train/mean_episodic_length':
                    utl_math.safe_mean([
                        episode_info['l'] for episode_info in episode_info_buf
                    ]),
                    'eval/mean_episodic_return':
                    utl_math.safe_mean([
                        episode_info['r']
                        for episode_info in eval_episode_info_buf
                    ]),
                    'eval/mean_episodic_length':
                    utl_math.safe_mean([
                        episode_info['l']
                        for episode_info in eval_episode_info_buf
                    ])
                },
                step=frames)

        # --- SAVE MODEL ---

        # save the model every save_interval updates and after the final update
        if iter_idx != 0 and (iter_idx % args.save_interval == 0
                              or iter_idx == args.num_updates - 1):
            print("Saving Actor-Critic Model.")
            torch.save(actor_critic.state_dict(),
                       os.path.join(save_dir, "policy{0}.pt".format(iter_idx)))

    # close envs
    train_envs.close()
    eval_envs.close()

    # --- TEST ---

    if args.test:
        print("Testing beginning.")
        episodic_return = utl_test.test(args=args,
                                        actor_critic=actor_critic,
                                        device=device)

        # save returns from train and test levels to analyse using interactive mode
        train_levels = torch.arange(
            args.train_start_level,
            args.train_start_level + args.train_num_levels)
        for i, level in enumerate(train_levels):
            wandb.log({
                'test/train_levels': level,
                'test/train_returns': episodic_return[0][i]
            })
        test_levels = torch.arange(
            args.test_start_level,
            args.test_start_level + args.test_num_levels)
        for i, level in enumerate(test_levels):
            wandb.log({
                'test/test_levels': level,
                'test/test_returns': episodic_return[1][i]
            })
        # log mean episodic returns for train and test levels to the run summary
        wandb.run.summary["train_mean_episodic_return"] = utl_math.safe_mean(
            episodic_return[0])
        wandb.run.summary["test_mean_episodic_return"] = utl_math.safe_mean(
            episodic_return[1])
Example #6
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--env_name',
                        type=str,
                        default='coinrun',
                        help='name of the environment to train on.')
    parser.add_argument(
        '--model',
        type=str,
        default='ppo',
        help='the model to use for training. {ppo, ibac, ibac_sni, dist_match}'
    )
    args, rest_args = parser.parse_known_args()
    env_name = args.env_name
    model = args.model

    # --- ARGUMENTS ---

    if model == 'ppo':
        args = args_ppo.get_args(rest_args)
    elif model == 'ibac':
        args = args_ibac.get_args(rest_args)
    elif model == 'ibac_sni':
        args = args_ibac_sni.get_args(rest_args)
    elif model == 'dist_match':
        args = args_dist_match.get_args(rest_args)
    else:
        raise NotImplementedError

    # place other args back into argparse.Namespace
    args.env_name = env_name
    args.model = model
    args.num_train_envs = (args.num_processes - args.num_val_envs
                           if args.num_val_envs > 0 else args.num_processes)

    # warnings
    if args.deterministic_execution:
        print('Invoking deterministic code execution.')
        if torch.backends.cudnn.enabled:
            warnings.warn('Running with deterministic CUDNN.')
        if args.num_processes > 1:
            raise RuntimeError(
                'If you want fully deterministic code, run it with num_processes=1. '
                'Warning: This will slow things down and might break A2C if '
                'policy_num_steps < env._max_episode_steps.')

    elif args.num_val_envs > 0 and (args.num_val_envs >= args.num_processes
                                    or not args.percentage_levels_train < 1.0):
        raise ValueError(
            'If --num_val_envs>0 then you must also have '
            '--num_val_envs < --num_processes and 0 < --percentage_levels_train < 1.'
        )

    elif (args.num_val_envs > 0 and not args.use_dist_matching
          and args.dist_matching_coef != 0):
        raise ValueError(
            'If --num_val_envs>0 and --use_dist_matching=False then you must also have '
            '--dist_matching_coef=0.')

    elif args.use_dist_matching and not args.num_val_envs > 0:
        raise ValueError(
            'If --use_dist_matching=True then you must also have '
            '0 < --num_val_envs < --num_processes and 0 < --percentage_levels_train < 1.'
        )

    elif args.analyse_rep and not args.use_bottleneck:
        raise ValueError('If --analyse_rep=True then you must also have '
                         '--use_bottleneck=True.')

    # --- TRAINING ---
    print("Setting up wandb logging.")

    # Weights & Biases logger
    if args.run_name is None:
        # default run name: {env_name}_{algo}_{time}
        now = datetime.datetime.now().strftime('_%d-%m_%H:%M:%S')
        args.run_name = args.env_name + '_' + args.algo + now
    # initialise wandb
    wandb.init(project=args.proj_name,
               name=args.run_name,
               group=args.group_name,
               config=args,
               monitor_gym=False)
    # save wandb dir path
    args.run_dir = wandb.run.dir
    # make directory for saving models
    save_dir = os.path.join(wandb.run.dir, 'models')
    os.makedirs(save_dir, exist_ok=True)

    # set random seed of random, torch and numpy
    utl.set_global_seed(args.seed, args.deterministic_execution)

    # compute device (assumption: CUDA if available, else CPU; `device` is
    # used by the env/model setup below but is not defined elsewhere in this
    # snippet)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # initialise environments for training
    print("Setting up Environments.")
    if args.num_val_envs > 0:
        train_num_levels = int(args.train_num_levels *
                               args.percentage_levels_train)
        val_start_level = args.train_start_level + train_num_levels
        val_num_levels = args.train_num_levels - train_num_levels
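        # e.g. with (hypothetical) train_start_level=0, train_num_levels=200 and
        # percentage_levels_train=0.8: levels 0-159 are used for training and
        # levels 160-199 for validation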
        train_envs = make_vec_envs(env_name=args.env_name,
                                   start_level=args.train_start_level,
                                   num_levels=train_num_levels,
                                   distribution_mode=args.distribution_mode,
                                   paint_vel_info=args.paint_vel_info,
                                   num_processes=args.num_train_envs,
                                   num_frame_stack=args.num_frame_stack,
                                   device=device)
        val_envs = make_vec_envs(env_name=args.env_name,
                                 start_level=val_start_level,
                                 num_levels=val_num_levels,
                                 distribution_mode=args.distribution_mode,
                                 paint_vel_info=args.paint_vel_info,
                                 num_processes=args.num_val_envs,
                                 num_frame_stack=args.num_frame_stack,
                                 device=device)
    else:
        train_envs = make_vec_envs(env_name=args.env_name,
                                   start_level=args.train_start_level,
                                   num_levels=args.train_num_levels,
                                   distribution_mode=args.distribution_mode,
                                   paint_vel_info=args.paint_vel_info,
                                   num_processes=args.num_processes,
                                   num_frame_stack=args.num_frame_stack,
                                   device=device)
    # initialise environments for evaluation
    eval_envs = make_vec_envs(env_name=args.env_name,
                              start_level=0,
                              num_levels=0,
                              distribution_mode=args.distribution_mode,
                              paint_vel_info=args.paint_vel_info,
                              num_processes=args.num_processes,
                              num_frame_stack=args.num_frame_stack,
                              device=device)
    _ = eval_envs.reset()
    # initialise environments for analysing the representation
    if args.analyse_rep:
        (analyse_rep_train1_envs, analyse_rep_train2_envs,
         analyse_rep_val_envs, analyse_rep_test_envs) = make_rep_analysis_envs(
             args, device)

    print("Setting up Actor-Critic model and Training algorithm.")
    # initialise policy network
    actor_critic = ACModel(obs_shape=train_envs.observation_space.shape,
                           action_space=train_envs.action_space,
                           hidden_size=args.hidden_size,
                           use_bottleneck=args.use_bottleneck,
                           sni_type=args.sni_type).to(device)

    # initialise policy training algorithm
    if args.algo == 'ppo':
        policy = PPO(actor_critic=actor_critic,
                     ppo_epoch=args.policy_ppo_epoch,
                     num_mini_batch=args.policy_num_mini_batch,
                     clip_param=args.policy_clip_param,
                     value_loss_coef=args.policy_value_loss_coef,
                     entropy_coef=args.policy_entropy_coef,
                     max_grad_norm=args.policy_max_grad_norm,
                     lr=args.policy_lr,
                     eps=args.policy_eps,
                     vib_coef=args.vib_coef,
                     sni_coef=args.sni_coef,
                     use_dist_matching=args.use_dist_matching,
                     dist_matching_loss=args.dist_matching_loss,
                     dist_matching_coef=args.dist_matching_coef,
                     num_train_envs=args.num_train_envs,
                     num_val_envs=args.num_val_envs)
    else:
        raise NotImplementedError

    # initialise rollout storage for the policy training algorithm
    rollouts = RolloutStorage(num_steps=args.policy_num_steps,
                              num_processes=args.num_processes,
                              obs_shape=train_envs.observation_space.shape,
                              action_space=train_envs.action_space)

    # count number of frames and updates
    frames = 0
    iter_idx = 0

    # update wandb args
    wandb.config.update(args)
    # wandb.watch(actor_critic, log="all") # to log gradients of actor-critic network

    update_start_time = time.time()

    # reset environments
    if args.num_val_envs > 0:
        obs = torch.cat([train_envs.reset(),
                         val_envs.reset()])  # obs.shape = (n_envs,C,H,W)
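        # the first num_train_envs rows of obs (and of the sampled actions)
        # come from the training envs and the remaining num_val_envs rows from
        # the validation envs; the action split in the rollout loop below
        # relies on this ordering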
    else:
        obs = train_envs.reset()  # obs.shape = (n_envs,C,H,W)

    # insert initial observation to rollout storage
    rollouts.obs[0].copy_(obs)
    rollouts.to(device)

    # initialise buffer for calculating mean episodic returns
    train_episode_info_buf = deque(maxlen=10)
    val_episode_info_buf = deque(maxlen=10)

    # calculate number of updates
    # number of frames ÷ number of policy steps before update ÷ number of processes
    args.num_batch = args.num_processes * args.policy_num_steps
    args.num_updates = int(args.num_frames) // args.num_batch
    print("Training beginning.")
    print("Number of updates: ", args.num_updates)
    for iter_idx in range(args.num_updates):
        print("Iter: ", iter_idx)

        # put actor-critic into train mode
        actor_critic.train()

        # rollout policy to collect num_batch of experience and place in storage
        for step in range(args.policy_num_steps):

            # sample actions from policy
            with torch.no_grad():
                value, action, action_log_prob, _ = actor_critic.act(
                    rollouts.obs[step])

            # observe rewards and next obs
            if args.num_val_envs > 0:
                obs, reward, done, infos = train_envs.step(
                    action[:args.num_train_envs, :])
                val_obs, val_reward, val_done, val_infos = val_envs.step(
                    action[args.num_train_envs:, :])
                obs = torch.cat([obs, val_obs])
                reward = torch.cat([reward, val_reward])
                done, val_done = list(done), list(val_done)
                done.extend(val_done)
                infos.extend(val_infos)
            else:
                obs, reward, done, infos = train_envs.step(action)

            # log episode info if episode finished
            for i, info in enumerate(infos):
                if i < args.num_train_envs and 'episode' in info.keys():
                    train_episode_info_buf.append(info['episode'])
                elif i >= args.num_train_envs and 'episode' in info.keys():
                    val_episode_info_buf.append(info['episode'])

            # create mask for episode ends
            masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                       for done_ in done]).to(device)

            # add experience to storage
            rollouts.insert(obs, reward, action, value, action_log_prob, masks)

            frames += args.num_processes

        # --- UPDATE ---

        # bootstrap next value prediction
        with torch.no_grad():
            next_value = actor_critic.get_value(rollouts.obs[-1]).detach()

        # compute returns for current rollouts
        rollouts.compute_returns(next_value, args.policy_gamma,
                                 args.policy_gae_lambda)

        # update actor-critic using policy gradient algo
        total_loss, value_loss, action_loss, dist_entropy, vib_kl, dist_matching_loss = policy.update(
            rollouts)

        # clean up storage after update
        rollouts.after_update()

        # --- LOGGING ---

        if iter_idx % args.log_interval == 0 or iter_idx == args.num_updates - 1:

            # --- EVALUATION ---
            eval_episode_info_buf = utl_eval.evaluate(
                eval_envs=eval_envs, actor_critic=actor_critic, device=device)

            # --- ANALYSE REPRESENTATION ---
            if args.analyse_rep:
                rep_measures = utl_rep.analyse_rep(
                    args=args,
                    train1_envs=analyse_rep_train1_envs,
                    train2_envs=analyse_rep_train2_envs,
                    val_envs=analyse_rep_val_envs,
                    test_envs=analyse_rep_test_envs,
                    actor_critic=actor_critic,
                    device=device)

            # get stats for run
            update_end_time = time.time()
            num_interval_updates = 1 if iter_idx == 0 else args.log_interval
            fps = num_interval_updates * (
                args.num_processes *
                args.policy_num_steps) / (update_end_time - update_start_time)
            update_start_time = update_end_time
            # calculates whether the value function is a good predictor of the
            # returns (ev close to 1) or worse than predicting nothing (ev <= 0)
            ev = utl_math.explained_variance(utl.sf01(rollouts.value_preds),
                                             utl.sf01(rollouts.returns))

            wandb.log(
                {
                    'misc/timesteps':
                    frames,
                    'misc/fps':
                    fps,
                    'misc/explained_variance':
                    float(ev),
                    'losses/total_loss':
                    total_loss,
                    'losses/value_loss':
                    value_loss,
                    'losses/action_loss':
                    action_loss,
                    'losses/dist_entropy':
                    dist_entropy,
                    'train/mean_episodic_return':
                    utl_math.safe_mean([
                        episode_info['r']
                        for episode_info in train_episode_info_buf
                    ]),
                    'train/mean_episodic_length':
                    utl_math.safe_mean([
                        episode_info['l']
                        for episode_info in train_episode_info_buf
                    ]),
                    'eval/mean_episodic_return':
                    utl_math.safe_mean([
                        episode_info['r']
                        for episode_info in eval_episode_info_buf
                    ]),
                    'eval/mean_episodic_length':
                    utl_math.safe_mean([
                        episode_info['l']
                        for episode_info in eval_episode_info_buf
                    ])
                },
                step=iter_idx)
            if args.use_bottleneck:
                wandb.log({'losses/vib_kl': vib_kl}, step=iter_idx)
            if args.num_val_envs > 0:
                wandb.log(
                    {
                        'losses/dist_matching_loss':
                        dist_matching_loss,
                        'val/mean_episodic_return':
                        utl_math.safe_mean([
                            episode_info['r']
                            for episode_info in val_episode_info_buf
                        ]),
                        'val/mean_episodic_length':
                        utl_math.safe_mean([
                            episode_info['l']
                            for episode_info in val_episode_info_buf
                        ])
                    },
                    step=iter_idx)
            if args.analyse_rep:
                wandb.log(
                    {
                        "analysis/" + key: val
                        for key, val in rep_measures.items()
                    },
                    step=iter_idx)

        # --- SAVE MODEL ---

        # save the model every save_interval updates and after the final update
        if iter_idx != 0 and (iter_idx % args.save_interval == 0
                              or iter_idx == args.num_updates - 1):
            print("Saving Actor-Critic Model.")
            torch.save(actor_critic.state_dict(),
                       os.path.join(save_dir, "policy{0}.pt".format(iter_idx)))

    # close envs
    train_envs.close()
    eval_envs.close()

    # --- TEST ---

    if args.test:
        print("Testing beginning.")
        episodic_return, latents_z = utl_test.test(args=args,
                                                   actor_critic=actor_critic,
                                                   device=device)

        # save returns from train and test levels to analyse using interactive mode
        train_levels = torch.arange(
            args.train_start_level,
            args.train_start_level + args.train_num_levels)
        for i, level in enumerate(train_levels):
            wandb.log({
                'test/train_levels': level,
                'test/train_returns': episodic_return[0][i]
            })
        test_levels = torch.arange(
            args.test_start_level,
            args.test_start_level + args.test_num_levels)
        for i, level in enumerate(test_levels):
            wandb.log({
                'test/test_levels': level,
                'test/test_returns': episodic_return[1][i]
            })
        # log mean episodic returns for train and test levels to the run summary
        wandb.run.summary["train_mean_episodic_return"] = utl_math.safe_mean(
            episodic_return[0])
        wandb.run.summary["test_mean_episodic_return"] = utl_math.safe_mean(
            episodic_return[1])

        # plot latent representation
        if args.plot_pca:
            print("Plotting PCA of Latent Representation.")
            utl_rep.pca(args, latents_z)