Example #1
    def fit(self, paths, targvals):
        X = np.concatenate([self._preproc(p) for p in paths])
        y = np.concatenate(targvals)
        logger.record_tabular("EVBefore", explained_variance(self._predict(X), y))
        for _ in range(25):
            self.do_update(X, y)
        logger.record_tabular("EVAfter", explained_variance(self._predict(X), y))
Example #2
def main():

    # --- ARGUMENTS ---
    parser = argparse.ArgumentParser()
    parser.add_argument('--env_name',
                        type=str,
                        default='coinrun',
                        help='name of the environment to train on.')
    parser.add_argument('--model',
                        type=str,
                        default='ppo',
                        help='the model to use for training.')
    args, rest_args = parser.parse_known_args()
    env_name = args.env_name
    model = args.model

    # get arguments
    args = args_pretrain_aup.get_args(rest_args)
    # place other args back into argparse.Namespace
    args.env_name = env_name
    args.model = model

    # Weights & Biases logger
    if args.run_name is None:
        # build the run name as {env_name}_{algo}_{time}
        now = datetime.datetime.now().strftime('_%d-%m_%H:%M:%S')
        args.run_name = args.env_name + '_' + args.algo + now
    # initialise wandb
    wandb.init(name=args.run_name,
               project=args.proj_name,
               group=args.group_name,
               config=args,
               monitor_gym=False)
    # save wandb dir path
    args.run_dir = wandb.run.dir
    wandb.config.update(args)

    # set random seed of random, torch and numpy
    utl.set_global_seed(args.seed, args.deterministic_execution)

    # --- OBTAIN DATA FOR TRAINING R_aux ---
    print("Gathering data for R_aux Model.")

    # gather observations for pretraining the auxiliary reward function (CB-VAE)
    envs = make_vec_envs(env_name=args.env_name,
                         start_level=0,
                         num_levels=0,
                         distribution_mode=args.distribution_mode,
                         paint_vel_info=args.paint_vel_info,
                         num_processes=args.num_processes,
                         num_frame_stack=args.num_frame_stack,
                         device=device)

    # number of updates = total frames ÷ (policy steps per update × number of processes)
    num_batch = args.num_processes * args.policy_num_steps
    num_updates = int(args.num_frames_r_aux) // num_batch

    # create tensor to store env observations
    obs_data = torch.zeros(num_updates * args.policy_num_steps + 1,
                           args.num_processes, *envs.observation_space.shape)
    # reset environments
    obs = envs.reset()  # obs.shape = (n_env,C,H,W)
    obs_data[0].copy_(obs)
    obs = obs.to(device)

    for iter_idx in range(num_updates):
        # roll out a random agent to collect num_batch transitions and store the observations
        for step in range(args.policy_num_steps):
            # sample actions from random agent
            action = torch.randint(0, envs.action_space.n,
                                   (args.num_processes, 1))
            # observe rewards and next obs
            obs, reward, done, infos = envs.step(action)
            # store obs
            obs_data[1 + iter_idx * args.policy_num_steps + step].copy_(obs)
    # close envs
    envs.close()

    # --- TRAIN R_aux (CB-VAE) ---
    # define CB-VAE where the encoder will be used as the auxiliary reward function R_aux
    print("Training R_aux Model.")

    # create dataloader for observations gathered
    obs_data = obs_data.reshape(-1, *envs.observation_space.shape)
    sampler = BatchSampler(SubsetRandomSampler(range(obs_data.size(0))),
                           args.cb_vae_batch_size,
                           drop_last=False)

    # initialise CB-VAE
    cb_vae = CBVAE(obs_shape=envs.observation_space.shape,
                   latent_dim=args.cb_vae_latent_dim).to(device)
    # optimiser
    optimiser = torch.optim.Adam(cb_vae.parameters(),
                                 lr=args.cb_vae_learning_rate)
    # put CB-VAE into train mode
    cb_vae.train()

    measures = defaultdict(list)
    for epoch in range(args.cb_vae_epochs):
        print("Epoch: ", epoch)
        start_time = time.time()
        batch_loss = 0
        for indices in sampler:
            obs = obs_data[indices].to(device)
            # zero accumulated gradients
            cb_vae.zero_grad()
            # forward pass through CB-VAE
            recon_batch, mu, log_var = cb_vae(obs)
            # calculate loss
            loss = cb_vae_loss(recon_batch, obs, mu, log_var)
            # backpropagation: compute gradients
            loss.backward()
            # update CB-VAE parameters
            optimiser.step()

            # accumulate epoch loss, weighted by mini-batch size
            batch_loss += loss.item() * obs.size(0)
        # log losses per epoch
        wandb.log({
            'cb_vae/loss': batch_loss / obs_data.size(0),
            'cb_vae/time_taken': time.time() - start_time,
            'cb_vae/epoch': epoch
        })
    indices = np.random.randint(0, obs.size(0), args.cb_vae_num_samples**2)
    measures['true_images'].append(obs[indices].detach().cpu().numpy())
    measures['recon_images'].append(
        recon_batch[indices].detach().cpu().numpy())

    # plot ground truth images
    plt.rcParams.update({'font.size': 10})
    fig, axs = plt.subplots(args.cb_vae_num_samples,
                            args.cb_vae_num_samples,
                            figsize=(20, 20))
    n = args.cb_vae_num_samples
    for i, img in enumerate(measures['true_images'][0]):
        axs[i // n][i % n].imshow(img.transpose(1, 2, 0))
        axs[i // n][i % n].axis('off')
    wandb.log({"Ground Truth Images": wandb.Image(plt)})
    # plot reconstructed images
    fig, axs = plt.subplots(args.cb_vae_num_samples,
                            args.cb_vae_num_samples,
                            figsize=(20, 20))
    for i, img in enumerate(measures['recon_images'][0]):
        axs[i // n][i % n].imshow(img.transpose(1, 2, 0))
        axs[i // n][i % n].axis('off')
    wandb.log({"Reconstructed Images": wandb.Image(plt)})

    # --- TRAIN Q_aux ---
    # train a PPO agent whose value head is replaced by an action-value head, trained on R_aux instead of the environment reward
    print("Training Q_aux Model.")

    # initialise environments for training Q_aux
    envs = make_vec_envs(env_name=args.env_name,
                         start_level=0,
                         num_levels=0,
                         distribution_mode=args.distribution_mode,
                         paint_vel_info=args.paint_vel_info,
                         num_processes=args.num_processes,
                         num_frame_stack=args.num_frame_stack,
                         device=device)

    # initialise policy network
    actor_critic = QModel(obs_shape=envs.observation_space.shape,
                          action_space=envs.action_space,
                          hidden_size=args.hidden_size).to(device)

    # initialise policy trainer
    if args.algo == 'ppo':
        policy = PPO(actor_critic=actor_critic,
                     ppo_epoch=args.policy_ppo_epoch,
                     num_mini_batch=args.policy_num_mini_batch,
                     clip_param=args.policy_clip_param,
                     value_loss_coef=args.policy_value_loss_coef,
                     entropy_coef=args.policy_entropy_coef,
                     max_grad_norm=args.policy_max_grad_norm,
                     lr=args.policy_lr,
                     eps=args.policy_eps)
    else:
        raise NotImplementedError

    # initialise rollout storage for the policy
    rollouts = RolloutStorage(num_steps=args.policy_num_steps,
                              num_processes=args.num_processes,
                              obs_shape=envs.observation_space.shape,
                              action_space=envs.action_space)

    # count number of frames and updates
    frames = 0
    iter_idx = 0

    update_start_time = time.time()
    # reset environments
    obs = envs.reset()  # obs.shape = (n_envs,C,H,W)
    # insert initial observation to rollout storage
    rollouts.obs[0].copy_(obs)
    rollouts.to(device)

    # initialise buffer for calculating mean episodic returns
    episode_info_buf = deque(maxlen=10)

    # calculate number of updates
    # = total frames ÷ (policy steps per update × number of processes)
    args.num_batch = args.num_processes * args.policy_num_steps
    args.num_updates = int(args.num_frames_q_aux) // args.num_batch
    print("Number of updates: ", args.num_updates)
    for iter_idx in range(args.num_updates):
        print("Iter: ", iter_idx)

        # put actor-critic into train mode
        actor_critic.train()

        # roll out the policy to collect num_batch transitions and store them in the rollout buffer
        for step in range(args.policy_num_steps):

            with torch.no_grad():
                # sample actions from policy
                value, action, action_log_prob = actor_critic.act(
                    rollouts.obs[step])
                # obtain reward R_aux from encoder of CB-VAE
                r_aux, _, _ = cb_vae.encode(rollouts.obs[step])

            # observe rewards and next obs
            obs, _, done, infos = envs.step(action)

            # log episode info if episode finished
            for i, info in enumerate(infos):
                if 'episode' in info.keys():
                    episode_info_buf.append(info['episode'])
            # create mask for episode ends
            masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                       for done_ in done]).to(device)

            # add experience to policy buffer
            rollouts.insert(obs, r_aux, action, value, action_log_prob, masks)

            frames += args.num_processes

        # --- UPDATE ---

        # bootstrap next value prediction
        with torch.no_grad():
            next_value = actor_critic.get_value(rollouts.obs[-1]).detach()

        # compute returns for current rollouts
        rollouts.compute_returns(next_value, args.policy_gamma,
                                 args.policy_gae_lambda)

        # update actor-critic using policy gradient algo
        total_loss, value_loss, action_loss, dist_entropy = policy.update(
            rollouts)

        # clean up after update
        rollouts.after_update()

        # --- LOGGING ---

        if iter_idx % args.log_interval == 0 or iter_idx == args.num_updates - 1:
            # get stats for run
            update_end_time = time.time()
            num_interval_updates = 1 if iter_idx == 0 else args.log_interval
            fps = num_interval_updates * (
                args.num_processes *
                args.policy_num_steps) / (update_end_time - update_start_time)
            update_start_time = update_end_time
            # calculate whether the value function is a good predictor of the returns (ev close to 1)
            # or worse than predicting nothing (ev <= 0)
            ev = utl_math.explained_variance(utl.sf01(rollouts.value_preds),
                                             utl.sf01(rollouts.returns))

            wandb.log({
                'q_aux_misc/timesteps':
                frames,
                'q_aux_misc/fps':
                fps,
                'q_aux_misc/explained_variance':
                float(ev),
                'q_aux_losses/total_loss':
                total_loss,
                'q_aux_losses/value_loss':
                value_loss,
                'q_aux_losses/action_loss':
                action_loss,
                'q_aux_losses/dist_entropy':
                dist_entropy,
                'q_aux_train/mean_episodic_return':
                utl_math.safe_mean(
                    [episode_info['r'] for episode_info in episode_info_buf]),
                'q_aux_train/mean_episodic_length':
                utl_math.safe_mean(
                    [episode_info['l'] for episode_info in episode_info_buf])
            })

    # close envs
    envs.close()

    # --- SAVE MODEL ---
    print("Saving Q_aux Model.")
    torch.save(actor_critic.state_dict(), args.q_aux_path)
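
The `cb_vae_loss` above is imported from elsewhere and not shown. As a rough, hypothetical sketch, a VAE-style objective combines a per-pixel reconstruction term with a KL term against the unit-Gaussian prior; the continuous-Bernoulli variant used by a CB-VAE additionally adds a log-normalising constant, omitted here (the function name below is illustrative, not the repository's implementation):

import torch
import torch.nn.functional as F

def vae_style_loss(recon_batch, obs, mu, log_var):
    # illustrative stand-in for cb_vae_loss, not the exact implementation used above
    # reconstruction term: binary cross-entropy over pixels, summed over the batch
    recon = F.binary_cross_entropy(recon_batch, obs, reduction='sum')
    # KL( N(mu, sigma^2) || N(0, I) )
    kl = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    return (recon + kl) / obs.size(0)
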
Example #3
def fit(policy,
        env,
        seed,
        nsteps=5,
        total_timesteps=int(80e6),
        vf_coef=0.5,
        ent_coef=0.01,
        max_grad_norm=0.5,
        lr=7e-4,
        lrschedule='linear',
        epsilon=1e-5,
        alpha=0.99,
        gamma=0.99,
        log_interval=100):

    set_global_seeds(seed)

    model = A2C(policy=policy,
                observation_space=env.observation_space,
                action_space=env.action_space,
                nenvs=env.num_envs,
                nsteps=nsteps,
                ent_coef=ent_coef,
                vf_coef=vf_coef,
                max_grad_norm=max_grad_norm,
                lr=lr,
                alpha=alpha,
                epsilon=epsilon,
                total_timesteps=total_timesteps,
                lrschedule=lrschedule)
    session = model.init_session()
    tf.global_variables_initializer().run(session=session)
    env_runner = Environment(env, model, nsteps=nsteps, gamma=gamma)

    nbatch = env.num_envs * nsteps
    tstart = time.time()
    writer = tf.summary.FileWriter('output', session.graph)
    for update in range(1, total_timesteps // nbatch + 1):
        tf.reset_default_graph()
        obs, states, rewards, masks, actions, values = env_runner.run(session)
        policy_loss, value_loss, policy_entropy = model.predict(
            observations=obs,
            states=states,
            rewards=rewards,
            masks=masks,
            actions=actions,
            values=values,
            session=session)
        nseconds = time.time() - tstart
        fps = int((update * nbatch) / nseconds)
        if update % log_interval == 0 or update == 1:
            ev = explained_variance(values, rewards)
            logger.record_tabular("nupdates", update)
            logger.record_tabular("total_timesteps", update * nbatch)
            logger.record_tabular("fps", fps)
            logger.record_tabular("policy_entropy", float(policy_entropy))
            logger.record_tabular("value_loss", float(value_loss))
            logger.record_tabular("explained_variance", float(ev))
            logger.dump_tabular()
    env.close()
    writer.close()
    session.close()
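
`lrschedule='linear'` suggests the learning rate is annealed from `lr` towards 0 over `total_timesteps`, as in the Baselines `Scheduler`. A minimal illustrative schedule (class name and API are assumptions, not the object `A2C` actually consumes):

class LinearSchedule:
    def __init__(self, value, total_steps):
        self.value = value
        self.total_steps = total_steps
        self.steps_taken = 0

    def next_value(self):
        # decay linearly from the initial value to zero over total_steps calls
        current = self.value * max(1.0 - self.steps_taken / self.total_steps, 0.0)
        self.steps_taken += 1
        return current
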
Example #4
def fit(
        policy,
        env,
        nsteps,
        total_timesteps,
        ent_coef,
        lr,
        vf_coef=0.5,
        max_grad_norm=0.5,
        gamma=0.99,
        lam=0.95,
        log_interval=10,
        nminibatches=4,
        noptepochs=4,
        cliprange=0.2,
        save_interval=0,
        load_path=None
):

    if isinstance(lr, float):
        lr = constfn(lr)
    else:
        assert callable(lr)
    if isinstance(cliprange, float):
        cliprange = constfn(cliprange)
    else:
        assert callable(cliprange)
    total_timesteps = int(total_timesteps)

    nenvs = env.num_envs
    # nenvs = 8
    ob_space = env.observation_space
    ac_space = env.action_space
    nbatch = nenvs * nsteps
    nbatch_train = nbatch // nminibatches

    model = PPO2(
        policy=policy,
        observation_space=ob_space,
        action_space=ac_space,
        nbatch_act=nenvs,
        nbatch_train=nbatch_train,
        nsteps=nsteps,
        ent_coef=ent_coef,
        vf_coef=vf_coef,
        max_grad_norm=max_grad_norm
    )
    Agent().init_vars()
    # if save_interval and logger.get_dir():
    #     import cloudpickle
    #     with open(os.path.join(logger.get_dir(), 'make_model.pkl'), 'wb') as fh:
    #         fh.write(cloudpickle.dumps(make_model))
    # model = make_model()
    # if load_path is not None:
    #     model.load(load_path)
    runner = Environment(env=env, model=model, nsteps=nsteps, gamma=gamma, lam=lam)

    epinfobuf = deque(maxlen=100)
    tfirststart = time.time()

    nupdates = total_timesteps//nbatch
    for update in range(1, nupdates+1):
        assert nbatch % nminibatches == 0
        nbatch_train = nbatch // nminibatches
        tstart = time.time()
        frac = 1.0 - (update - 1.0) / nupdates
        lrnow = lr(frac)
        cliprangenow = cliprange(frac)
        obs, returns, masks, actions, values, neglogpacs, states, epinfos = runner.run()
        epinfobuf.extend(epinfos)
        mblossvals = []
        if states is None:  # nonrecurrent version
            inds = np.arange(nbatch)
            for _ in range(noptepochs):
                np.random.shuffle(inds)
                for start in range(0, nbatch, nbatch_train):
                    end = start + nbatch_train
                    mbinds = inds[start:end]
                    slices = (
                        arr[mbinds] for arr in (obs,
                                                returns,
                                                masks,
                                                actions,
                                                values,
                                                neglogpacs)
                    )
                    mblossvals.append(model.predict(lrnow, cliprangenow, *slices))
        else:  # recurrent version
            assert nenvs % nminibatches == 0
            envsperbatch = nenvs // nminibatches
            envinds = np.arange(nenvs)
            flatinds = np.arange(nenvs * nsteps).reshape(nenvs, nsteps)
            envsperbatch = nbatch_train // nsteps
            for _ in range(noptepochs):
                np.random.shuffle(envinds)
                for start in range(0, nenvs, envsperbatch):
                    end = start + envsperbatch
                    mbenvinds = envinds[start:end]
                    mbflatinds = flatinds[mbenvinds].ravel()
                    slices = (
                        arr[mbflatinds] for arr in (obs,
                                                    returns,
                                                    masks,
                                                    actions,
                                                    values,
                                                    neglogpacs)
                    )
                    mbstates = states[mbenvinds]
                    mblossvals.append(model.predict(lrnow, cliprangenow, *slices, mbstates))

        lossvals = np.mean(mblossvals, axis=0)
        tnow = time.time()
        fps = int(nbatch / (tnow - tstart))
        if update % log_interval == 0 or update == 1:
            ev = explained_variance(values, returns)
            logger.logkv("serial_timesteps", update*nsteps)
            logger.logkv("nupdates", update)
            logger.logkv("total_timesteps", update*nbatch)
            logger.logkv("fps", fps)
            logger.logkv("explained_variance", float(ev))
            logger.logkv('eprewmean', safemean([epinfo['r'] for epinfo in epinfobuf]))
            logger.logkv('eplenmean', safemean([epinfo['l'] for epinfo in epinfobuf]))
            logger.logkv('time_elapsed', tnow - tfirststart)
            for (lossval, lossname) in zip(lossvals, model.loss_names):
                logger.logkv(lossname, lossval)
            logger.dumpkvs()
        if save_interval and (update % save_interval == 0 or update == 1) and logger.get_dir():
            checkdir = os.path.join(logger.get_dir(), 'checkpoints')
            os.makedirs(checkdir, exist_ok=True)
            savepath = os.path.join(checkdir, '%.5i' % update)
            print('Saving to', savepath)
            model.save(savepath)
    env.close()
    return model
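
This function relies on two small helpers with the same names as in OpenAI Baselines' ppo2: `constfn`, which wraps a constant so the learning rate and clip range can always be called with the remaining-progress fraction `frac`, and `safemean`, which returns NaN instead of raising on an empty episode buffer. Sketches of both, assuming they match the Baselines versions:

import numpy as np

def constfn(val):
    # wrap a constant so it can be called like a schedule: f(frac) -> val
    def f(_):
        return val
    return f

def safemean(xs):
    # mean that degrades gracefully to NaN when the buffer is empty
    return np.nan if len(xs) == 0 else np.mean(xs)
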
Example #5
    def rollouts(self):
        # Prepare for rollouts
        # ----------------------------------------
        seg_gen = self.traj_segment_generator(self.pi,
                                              self.env,
                                              self.timesteps_per_actorbatch,
                                              stochastic=True)

        episodes_so_far = 0
        timesteps_so_far = 0
        iters_so_far = 0
        tstart = time.time()
        lenbuffer = deque(maxlen=100)  # rolling buffer for episode lengths
        rewbuffer = deque(maxlen=100)  # rolling buffer for episode rewards

        assert sum([
            self.max_iters > 0, self.max_timesteps > 0, self.max_episodes > 0,
            self.max_seconds > 0
        ]) == 1, "Only one time constraint permitted"

        while True:
            if self.callback:
                self.callback(locals(), globals())
            if self.max_timesteps and timesteps_so_far >= self.max_timesteps:
                break
            elif self.max_episodes and episodes_so_far >= self.max_episodes:
                break
            elif self.max_iters and iters_so_far >= self.max_iters:
                break
            elif self.max_seconds and time.time() - tstart >= self.max_seconds:
                break

            if self.schedule == 'constant':
                cur_lrmult = 1.0
            elif self.schedule == 'linear':
                cur_lrmult = max(
                    1.0 - float(timesteps_so_far) / self.max_timesteps, 0)
            else:
                raise NotImplementedError

            logger.log("********** Iteration %i ************" % iters_so_far)

            seg = seg_gen.__next__()
            self.add_vtarg_and_adv(seg, self.gamma, self.lam)

            ob, ac, atarg, tdlamret = seg["ob"], seg["ac"], seg["adv"], seg["tdlamret"]
            vpredbefore = seg["vpred"]  # predicted value function before update
            atarg = (atarg - atarg.mean()) / atarg.std()  # standardized advantage estimate
            d = Dataset(dict(ob=ob, ac=ac, atarg=atarg, vtarg=tdlamret),
                        shuffle=not self.pi.recurrent)
            optim_batchsize = self.optim_batchsize or ob.shape[0]

            if hasattr(self.pi, "ob_rms"):
                self.pi.ob_rms.update(ob)  # update running mean/std for policy

            self.assign_old_eq_new()  # set old parameter values to new parameter values
            logger.log("Optimizing...")
            logger.log(fmt_row(13, self.loss_names))
            # Here we do a bunch of optimization epochs over the data
            for _ in range(self.optim_epochs):
                losses = []  # list of tuples, each of which gives the loss for a minibatch
                for batch in d.iterate_once(optim_batchsize):
                    *newlosses, g = self.lossandgrad(batch["ob"], batch["ac"],
                                                     batch["atarg"],
                                                     batch["vtarg"],
                                                     cur_lrmult)
                    self.adam.update(g, self.optim_stepsize * cur_lrmult)
                    losses.append(newlosses)
                logger.log(fmt_row(13, np.mean(losses, axis=0)))

            logger.log("Evaluating losses...")
            losses = []
            for batch in d.iterate_once(optim_batchsize):
                newlosses = self.loss(batch["ob"], batch["ac"], batch["atarg"],
                                      batch["vtarg"], cur_lrmult)
                losses.append(newlosses)
            meanlosses, _, _ = mpi_moments(losses, axis=0)
            logger.log(fmt_row(13, meanlosses))
            for (lossval, name) in zipsame(meanlosses, self.loss_names):
                logger.record_tabular("loss_" + name, lossval)
            logger.record_tabular("ev_tdlam_before",
                                  explained_variance(vpredbefore, tdlamret))
            lrlocal = (seg["ep_lens"], seg["ep_rets"])  # local values
            listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
            lens, rews = map(self.flatten_lists, zip(*listoflrpairs))
            lenbuffer.extend(lens)
            rewbuffer.extend(rews)
            logger.record_tabular("EpLenMean", np.mean(lenbuffer))
            logger.record_tabular("EpRewMean", np.mean(rewbuffer))
            logger.record_tabular("EpThisIter", len(lens))
            episodes_so_far += len(lens)
            timesteps_so_far += sum(lens)
            iters_so_far += 1
            logger.record_tabular("EpisodesSoFar", episodes_so_far)
            logger.record_tabular("TimestepsSoFar", timesteps_so_far)
            logger.record_tabular("TimeElapsed", time.time() - tstart)
            if MPI.COMM_WORLD.Get_rank() == 0:
                logger.dump_tabular()

        return self.pi
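
`add_vtarg_and_adv` is expected to fill `seg["adv"]` and `seg["tdlamret"]` with Generalized Advantage Estimation targets. A sketch in the spirit of the Baselines pposgd implementation, assuming the segment dict also carries `"new"` (episode-start flags), `"rew"`, `"vpred"` and `"nextvpred"`:

import numpy as np

def add_vtarg_and_adv(seg, gamma, lam):
    # GAE(lambda): adv[t] = delta[t] + gamma * lam * (1 - done[t+1]) * adv[t+1]
    new = np.append(seg["new"], 0)
    vpred = np.append(seg["vpred"], seg["nextvpred"])
    T = len(seg["rew"])
    seg["adv"] = gaelam = np.empty(T, dtype='float32')
    lastgaelam = 0.0
    for t in reversed(range(T)):
        nonterminal = 1 - new[t + 1]
        delta = seg["rew"][t] + gamma * vpred[t + 1] * nonterminal - vpred[t]
        gaelam[t] = lastgaelam = delta + gamma * lam * nonterminal * lastgaelam
    seg["tdlamret"] = seg["adv"] + seg["vpred"]
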
Example #6
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--env_name',
                        type=str,
                        default='coinrun',
                        help='name of the environment to train on.')
    parser.add_argument('--model',
                        type=str,
                        default='ppo',
                        help='the model to use for training. {ppo, ppo_aup}')
    args, rest_args = parser.parse_known_args()
    env_name = args.env_name
    model = args.model

    # --- ARGUMENTS ---

    if model == 'ppo':
        args = args_ppo.get_args(rest_args)
    elif model == 'ppo_aup':
        args = args_ppo_aup.get_args(rest_args)
    else:
        raise NotImplementedError

    # place other args back into argparse.Namespace
    args.env_name = env_name
    args.model = model

    # warnings
    if args.deterministic_execution:
        print('Invoking deterministic code execution.')
        if torch.backends.cudnn.enabled:
            warnings.warn('Running with deterministic CUDNN.')
        if args.num_processes > 1:
            raise RuntimeError(
                'If you want fully deterministic code, run it with num_processes=1. '
                'Warning: This will slow things down and might break A2C if '
                'policy_num_steps < env._max_episode_steps.')

    # --- TRAINING ---
    print("Setting up wandb logging.")

    # Weights & Biases logger
    if args.run_name is None:
        # build the run name as {env_name}_{algo}_{time}
        now = datetime.datetime.now().strftime('_%d-%m_%H:%M:%S')
        args.run_name = args.env_name + '_' + args.algo + now
    # initialise wandb
    wandb.init(project=args.proj_name,
               name=args.run_name,
               group=args.group_name,
               config=args,
               monitor_gym=False)
    # save wandb dir path
    args.run_dir = wandb.run.dir
    # make directory for saving models
    save_dir = os.path.join(wandb.run.dir, 'models')
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    # set random seed of random, torch and numpy
    utl.set_global_seed(args.seed, args.deterministic_execution)

    print("Setting up Environments.")
    # initialise environments for training
    train_envs = make_vec_envs(env_name=args.env_name,
                               start_level=args.train_start_level,
                               num_levels=args.train_num_levels,
                               distribution_mode=args.distribution_mode,
                               paint_vel_info=args.paint_vel_info,
                               num_processes=args.num_processes,
                               num_frame_stack=args.num_frame_stack,
                               device=device)
    # initialise environments for evaluation
    eval_envs = make_vec_envs(env_name=args.env_name,
                              start_level=0,
                              num_levels=0,
                              distribution_mode=args.distribution_mode,
                              paint_vel_info=args.paint_vel_info,
                              num_processes=args.num_processes,
                              num_frame_stack=args.num_frame_stack,
                              device=device)
    _ = eval_envs.reset()

    print("Setting up Actor-Critic model and Training algorithm.")
    # initialise policy network
    actor_critic = ACModel(obs_shape=train_envs.observation_space.shape,
                           action_space=train_envs.action_space,
                           hidden_size=args.hidden_size).to(device)

    # initialise policy training algorithm
    if args.algo == 'ppo':
        policy = PPO(actor_critic=actor_critic,
                     ppo_epoch=args.policy_ppo_epoch,
                     num_mini_batch=args.policy_num_mini_batch,
                     clip_param=args.policy_clip_param,
                     value_loss_coef=args.policy_value_loss_coef,
                     entropy_coef=args.policy_entropy_coef,
                     max_grad_norm=args.policy_max_grad_norm,
                     lr=args.policy_lr,
                     eps=args.policy_eps)
    else:
        raise NotImplementedError

    # initialise rollout storage for the policy training algorithm
    rollouts = RolloutStorage(num_steps=args.policy_num_steps,
                              num_processes=args.num_processes,
                              obs_shape=train_envs.observation_space.shape,
                              action_space=train_envs.action_space)

    # initialise Q_aux function(s) for AUP
    if args.use_aup:
        print("Initialising Q_aux models.")
        q_aux = [
            QModel(obs_shape=train_envs.observation_space.shape,
                   action_space=train_envs.action_space,
                   hidden_size=args.hidden_size).to(device)
            for _ in range(args.num_q_aux)
        ]
        if args.num_q_aux == 1:
            # load weights to model
            path = args.q_aux_dir + "0.pt"
            q_aux[0].load_state_dict(torch.load(path))
            q_aux[0].eval()
        else:
            # get the number of q_aux checkpoints available to choose from
            args.max_num_q_aux = len(os.listdir(args.q_aux_dir))
            q_aux_models = random.sample(list(range(0, args.max_num_q_aux)),
                                         args.num_q_aux)
            # load weights to models
            for i, model in enumerate(q_aux):
                path = args.q_aux_dir + str(q_aux_models[i]) + ".pt"
                model.load_state_dict(torch.load(path))
                model.eval()

    # count number of frames and updates
    frames = 0
    iter_idx = 0

    # update wandb args
    wandb.config.update(args)

    update_start_time = time.time()
    # reset environments
    obs = train_envs.reset()  # obs.shape = (num_processes,C,H,W)
    # insert initial observation to rollout storage
    rollouts.obs[0].copy_(obs)
    rollouts.to(device)

    # initialise buffer for calculating mean episodic returns
    episode_info_buf = deque(maxlen=10)

    # calculate number of updates
    # = total frames ÷ (policy steps per update × number of processes)
    args.num_batch = args.num_processes * args.policy_num_steps
    args.num_updates = int(args.num_frames) // args.num_batch

    # define the AUP coefficient and its per-update growth factor (geometric schedule)
    if args.use_aup:
        aup_coef = args.aup_coef_start
        aup_linear_increase_val = math.exp(
            math.log(args.aup_coef_end / args.aup_coef_start) /
            args.num_updates)

    print("Training beginning.")
    print("Number of updates: ", args.num_updates)
    for iter_idx in range(args.num_updates):
        print("Iter: ", iter_idx)

        # put actor-critic into train mode
        actor_critic.train()

        if args.use_aup:
            aup_measures = defaultdict(list)

        # roll out the policy to collect num_batch transitions and store them in the rollout buffer
        for step in range(args.policy_num_steps):

            # sample actions from policy
            with torch.no_grad():
                value, action, action_log_prob = actor_critic.act(
                    rollouts.obs[step])

            # observe rewards and next obs
            obs, reward, done, infos = train_envs.step(action)

            # calculate AUP reward
            if args.use_aup:
                intrinsic_reward = torch.zeros_like(reward)
                with torch.no_grad():
                    for model in q_aux:
                        # get action-values
                        action_values = model.get_action_value(
                            rollouts.obs[step])
                        # get action-value for action taken
                        action_value = torch.sum(
                            action_values * torch.nn.functional.one_hot(
                                action,
                                num_classes=train_envs.action_space.n).squeeze(
                                    dim=1),
                            dim=1)
                        # calculate the penalty
                        intrinsic_reward += torch.abs(
                            action_value.unsqueeze(dim=1) -
                            action_values[:, 4].unsqueeze(dim=1))
                intrinsic_reward /= args.num_q_aux
                # subtract the scaled AUP penalty from the extrinsic reward
                reward -= aup_coef * intrinsic_reward
                # log the intrinsic reward from the first env.
                aup_measures['intrinsic_reward'].append(aup_coef *
                                                        intrinsic_reward[0, 0])
                if done[0] and infos[0]['prev_level_complete'] == 1:
                    aup_measures['episode_complete'].append(2)
                elif done[0] and infos[0]['prev_level_complete'] == 0:
                    aup_measures['episode_complete'].append(1)
                else:
                    aup_measures['episode_complete'].append(0)

            # log episode info if episode finished
            for info in infos:
                if 'episode' in info.keys():
                    episode_info_buf.append(info['episode'])

            # create mask for episode ends
            masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                       for done_ in done]).to(device)

            # add experience to storage
            rollouts.insert(obs, reward, action, value, action_log_prob, masks)

            frames += args.num_processes

        # increase the AUP coefficient geometrically after every update
        if args.use_aup:
            aup_coef *= aup_linear_increase_val

        # --- UPDATE ---

        # bootstrap next value prediction
        with torch.no_grad():
            next_value = actor_critic.get_value(rollouts.obs[-1]).detach()

        # compute returns for current rollouts
        rollouts.compute_returns(next_value, args.policy_gamma,
                                 args.policy_gae_lambda)

        # update actor-critic using policy training algorithm
        total_loss, value_loss, action_loss, dist_entropy = policy.update(
            rollouts)

        # clean up storage after update
        rollouts.after_update()

        # --- LOGGING ---

        if iter_idx % args.log_interval == 0 or iter_idx == args.num_updates - 1:

            # --- EVALUATION ---
            eval_episode_info_buf = utl_eval.evaluate(
                eval_envs=eval_envs, actor_critic=actor_critic, device=device)

            # get stats for run
            update_end_time = time.time()
            num_interval_updates = 1 if iter_idx == 0 else args.log_interval
            fps = num_interval_updates * (
                args.num_processes *
                args.policy_num_steps) / (update_end_time - update_start_time)
            update_start_time = update_end_time
            # calculate whether the value function is a good predictor of the returns (ev close to 1)
            # or worse than predicting nothing (ev <= 0)
            ev = utl_math.explained_variance(utl.sf01(rollouts.value_preds),
                                             utl.sf01(rollouts.returns))

            if args.use_aup:
                step = frames - args.num_processes * args.policy_num_steps
                for i in range(args.policy_num_steps):
                    wandb.log(
                        {
                            'aup/intrinsic_reward':
                            aup_measures['intrinsic_reward'][i],
                            'aup/episode_complete':
                            aup_measures['episode_complete'][i]
                        },
                        step=step)
                    step += args.num_processes

            wandb.log(
                {
                    'misc/timesteps':
                    frames,
                    'misc/fps':
                    fps,
                    'misc/explained_variance':
                    float(ev),
                    'losses/total_loss':
                    total_loss,
                    'losses/value_loss':
                    value_loss,
                    'losses/action_loss':
                    action_loss,
                    'losses/dist_entropy':
                    dist_entropy,
                    'train/mean_episodic_return':
                    utl_math.safe_mean([
                        episode_info['r'] for episode_info in episode_info_buf
                    ]),
                    'train/mean_episodic_length':
                    utl_math.safe_mean([
                        episode_info['l'] for episode_info in episode_info_buf
                    ]),
                    'eval/mean_episodic_return':
                    utl_math.safe_mean([
                        episode_info['r']
                        for episode_info in eval_episode_info_buf
                    ]),
                    'eval/mean_episodic_length':
                    utl_math.safe_mean([
                        episode_info['l']
                        for episode_info in eval_episode_info_buf
                    ])
                },
                step=frames)

        # --- SAVE MODEL ---

        # save for every interval-th episode or for the last epoch
        if iter_idx != 0 and (iter_idx % args.save_interval == 0
                              or iter_idx == args.num_updates - 1):
            print("Saving Actor-Critic Model.")
            torch.save(actor_critic.state_dict(),
                       os.path.join(save_dir, "policy{0}.pt".format(iter_idx)))

    # close envs
    train_envs.close()
    eval_envs.close()

    # --- TEST ---

    if args.test:
        print("Testing beginning.")
        episodic_return = utl_test.test(args=args,
                                        actor_critic=actor_critic,
                                        device=device)

        # save returns from train and test levels to analyse using interactive mode
        train_levels = torch.arange(
            args.train_start_level,
            args.train_start_level + args.train_num_levels)
        for i, level in enumerate(train_levels):
            wandb.log({
                'test/train_levels': level,
                'test/train_returns': episodic_return[0][i]
            })
        test_levels = torch.arange(
            args.test_start_level,
            args.test_start_level + args.test_num_levels)
        for i, level in enumerate(test_levels):
            wandb.log({
                'test/test_levels': level,
                'test/test_returns': episodic_return[1][i]
            })
        # log returns from test envs
        wandb.run.summary["train_mean_episodic_return"] = utl_math.safe_mean(
            episodic_return[0])
        wandb.run.summary["test_mean_episodic_return"] = utl_math.safe_mean(
            episodic_return[1])
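
Note that the AUP coefficient schedule above is geometric rather than linear: multiplying `aup_coef` by `exp(log(aup_coef_end / aup_coef_start) / num_updates)` once per update carries it from `aup_coef_start` to `aup_coef_end` after `num_updates` updates. A quick check of that identity with made-up values:

import math

start, end, num_updates = 1e-3, 1e-1, 100
growth = math.exp(math.log(end / start) / num_updates)

coef = start
for _ in range(num_updates):
    coef *= growth
# growth ** num_updates == end / start, so coef ends at (approximately) end
print(abs(coef - end) < 1e-9)  # True up to floating-point error
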
Example #7
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--env_name',
                        type=str,
                        default='coinrun',
                        help='name of the environment to train on.')
    parser.add_argument(
        '--model',
        type=str,
        default='ppo',
        help='the model to use for training. {ppo, ibac, ibac_sni, dist_match}'
    )
    args, rest_args = parser.parse_known_args()
    env_name = args.env_name
    model = args.model

    # --- ARGUMENTS ---

    if model == 'ppo':
        args = args_ppo.get_args(rest_args)
    elif model == 'ibac':
        args = args_ibac.get_args(rest_args)
    elif model == 'ibac_sni':
        args = args_ibac_sni.get_args(rest_args)
    elif model == 'dist_match':
        args = args_dist_match.get_args(rest_args)
    else:
        raise NotImplementedError

    # place other args back into argparse.Namespace
    args.env_name = env_name
    args.model = model
    args.num_train_envs = args.num_processes - args.num_val_envs if args.num_val_envs > 0 else args.num_processes

    # warnings
    if args.deterministic_execution:
        print('Invoking deterministic code execution.')
        if torch.backends.cudnn.enabled:
            warnings.warn('Running with deterministic CUDNN.')
        if args.num_processes > 1:
            raise RuntimeError(
                'If you want fully deterministic code, run it with num_processes=1. '
                'Warning: This will slow things down and might break A2C if '
                'policy_num_steps < env._max_episode_steps.')

    elif args.num_val_envs > 0 and (args.num_val_envs >= args.num_processes
                                    or not args.percentage_levels_train < 1.0):
        raise ValueError(
            'If --num_val_envs > 0 then you must also have '
            '--num_val_envs < --num_processes and 0 < --percentage_levels_train < 1.'
        )

    elif args.num_val_envs > 0 and not args.use_dist_matching and args.dist_matching_coef != 0:
        raise ValueError(
            'If --num_val_envs>0 and --use_dist_matching=False then you must also have '
            '--dist_matching_coef=0.')

    elif args.use_dist_matching and not args.num_val_envs > 0:
        raise ValueError(
            'If --use_dist_matching=True then you must also have '
            '0 < --num_val_envs < --num_processes and 0 < --percentage_levels_train < 1.'
        )

    elif args.analyse_rep and not args.use_bottleneck:
        raise ValueError('If --analyse_rep=True then you must also have '
                         '--use_bottleneck=True.')

    # --- TRAINING ---
    print("Setting up wandb logging.")

    # Weights & Biases logger
    if args.run_name is None:
        # build the run name as {env_name}_{algo}_{time}
        now = datetime.datetime.now().strftime('_%d-%m_%H:%M:%S')
        args.run_name = args.env_name + '_' + args.algo + now
    # initialise wandb
    wandb.init(project=args.proj_name,
               name=args.run_name,
               group=args.group_name,
               config=args,
               monitor_gym=False)
    # save wandb dir path
    args.run_dir = wandb.run.dir
    # make directory for saving models
    save_dir = os.path.join(wandb.run.dir, 'models')
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    # set random seed of random, torch and numpy
    utl.set_global_seed(args.seed, args.deterministic_execution)

    # initialise environments for training
    print("Setting up Environments.")
    if args.num_val_envs > 0:
        train_num_levels = int(args.train_num_levels *
                               args.percentage_levels_train)
        val_start_level = args.train_start_level + train_num_levels
        val_num_levels = args.train_num_levels - train_num_levels
        train_envs = make_vec_envs(env_name=args.env_name,
                                   start_level=args.train_start_level,
                                   num_levels=train_num_levels,
                                   distribution_mode=args.distribution_mode,
                                   paint_vel_info=args.paint_vel_info,
                                   num_processes=args.num_train_envs,
                                   num_frame_stack=args.num_frame_stack,
                                   device=device)
        val_envs = make_vec_envs(env_name=args.env_name,
                                 start_level=val_start_level,
                                 num_levels=val_num_levels,
                                 distribution_mode=args.distribution_mode,
                                 paint_vel_info=args.paint_vel_info,
                                 num_processes=args.num_val_envs,
                                 num_frame_stack=args.num_frame_stack,
                                 device=device)
    else:
        train_envs = make_vec_envs(env_name=args.env_name,
                                   start_level=args.train_start_level,
                                   num_levels=args.train_num_levels,
                                   distribution_mode=args.distribution_mode,
                                   paint_vel_info=args.paint_vel_info,
                                   num_processes=args.num_processes,
                                   num_frame_stack=args.num_frame_stack,
                                   device=device)
    # initialise environments for evaluation
    eval_envs = make_vec_envs(env_name=args.env_name,
                              start_level=0,
                              num_levels=0,
                              distribution_mode=args.distribution_mode,
                              paint_vel_info=args.paint_vel_info,
                              num_processes=args.num_processes,
                              num_frame_stack=args.num_frame_stack,
                              device=device)
    _ = eval_envs.reset()
    # initialise environments for analysing the representation
    if args.analyse_rep:
        analyse_rep_train1_envs, analyse_rep_train2_envs, analyse_rep_val_envs, analyse_rep_test_envs = make_rep_analysis_envs(
            args, device)

    print("Setting up Actor-Critic model and Training algorithm.")
    # initialise policy network
    actor_critic = ACModel(obs_shape=train_envs.observation_space.shape,
                           action_space=train_envs.action_space,
                           hidden_size=args.hidden_size,
                           use_bottleneck=args.use_bottleneck,
                           sni_type=args.sni_type).to(device)

    # initialise policy training algorithm
    if args.algo == 'ppo':
        policy = PPO(actor_critic=actor_critic,
                     ppo_epoch=args.policy_ppo_epoch,
                     num_mini_batch=args.policy_num_mini_batch,
                     clip_param=args.policy_clip_param,
                     value_loss_coef=args.policy_value_loss_coef,
                     entropy_coef=args.policy_entropy_coef,
                     max_grad_norm=args.policy_max_grad_norm,
                     lr=args.policy_lr,
                     eps=args.policy_eps,
                     vib_coef=args.vib_coef,
                     sni_coef=args.sni_coef,
                     use_dist_matching=args.use_dist_matching,
                     dist_matching_loss=args.dist_matching_loss,
                     dist_matching_coef=args.dist_matching_coef,
                     num_train_envs=args.num_train_envs,
                     num_val_envs=args.num_val_envs)
    else:
        raise NotImplementedError

    # initialise rollout storage for the policy training algorithm
    rollouts = RolloutStorage(num_steps=args.policy_num_steps,
                              num_processes=args.num_processes,
                              obs_shape=train_envs.observation_space.shape,
                              action_space=train_envs.action_space)

    # count number of frames and updates
    frames = 0
    iter_idx = 0

    # update wandb args
    wandb.config.update(args)
    # wandb.watch(actor_critic, log="all") # to log gradients of actor-critic network

    update_start_time = time.time()

    # reset environments
    if args.num_val_envs > 0:
        obs = torch.cat([train_envs.reset(),
                         val_envs.reset()])  # obs.shape = (n_envs,C,H,W)
    else:
        obs = train_envs.reset()  # obs.shape = (n_envs,C,H,W)

    # insert initial observation to rollout storage
    rollouts.obs[0].copy_(obs)
    rollouts.to(device)

    # initialise buffer for calculating mean episodic returns
    train_episode_info_buf = deque(maxlen=10)
    val_episode_info_buf = deque(maxlen=10)

    # calculate number of updates
    # = total frames ÷ (policy steps per update × number of processes)
    args.num_batch = args.num_processes * args.policy_num_steps
    args.num_updates = int(args.num_frames) // args.num_batch
    print("Training beginning.")
    print("Number of updates: ", args.num_updates)
    for iter_idx in range(args.num_updates):
        print("Iter: ", iter_idx)

        # put actor-critic into train mode
        actor_critic.train()

        # roll out the policy to collect num_batch transitions and place them in the rollout buffer
        for step in range(args.policy_num_steps):

            # sample actions from policy
            with torch.no_grad():
                value, action, action_log_prob, _ = actor_critic.act(
                    rollouts.obs[step])

            # observe rewards and next obs
            if args.num_val_envs > 0:
                obs, reward, done, infos = train_envs.step(
                    action[:args.num_train_envs, :])
                val_obs, val_reward, val_done, val_infos = val_envs.step(
                    action[args.num_train_envs:, :])
                obs = torch.cat([obs, val_obs])
                reward = torch.cat([reward, val_reward])
                done, val_done = list(done), list(val_done)
                done.extend(val_done)
                infos.extend(val_infos)
            else:
                obs, reward, done, infos = train_envs.step(action)

            # log episode info if episode finished
            for i, info in enumerate(infos):
                if i < args.num_train_envs and 'episode' in info.keys():
                    train_episode_info_buf.append(info['episode'])
                elif i >= args.num_train_envs and 'episode' in info.keys():
                    val_episode_info_buf.append(info['episode'])

            # create mask for episode ends
            masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                       for done_ in done]).to(device)

            # add experience to storage
            rollouts.insert(obs, reward, action, value, action_log_prob, masks)

            frames += args.num_processes

        # --- UPDATE ---

        # bootstrap next value prediction
        with torch.no_grad():
            next_value = actor_critic.get_value(rollouts.obs[-1]).detach()

        # compute returns for current rollouts
        rollouts.compute_returns(next_value, args.policy_gamma,
                                 args.policy_gae_lambda)

        # update actor-critic using policy gradient algo
        total_loss, value_loss, action_loss, dist_entropy, vib_kl, dist_matching_loss = policy.update(
            rollouts)

        # clean up storage after update
        rollouts.after_update()

        # --- LOGGING ---

        if iter_idx % args.log_interval == 0 or iter_idx == args.num_updates - 1:

            # --- EVALUATION ---
            eval_episode_info_buf = utl_eval.evaluate(
                eval_envs=eval_envs, actor_critic=actor_critic, device=device)

            # --- ANALYSE REPRESENTATION ---
            if args.analyse_rep:
                rep_measures = utl_rep.analyse_rep(
                    args=args,
                    train1_envs=analyse_rep_train1_envs,
                    train2_envs=analyse_rep_train2_envs,
                    val_envs=analyse_rep_val_envs,
                    test_envs=analyse_rep_test_envs,
                    actor_critic=actor_critic,
                    device=device)

            # get stats for run
            update_end_time = time.time()
            num_interval_updates = 1 if iter_idx == 0 else args.log_interval
            fps = num_interval_updates * (
                args.num_processes *
                args.policy_num_steps) / (update_end_time - update_start_time)
            update_start_time = update_end_time
            # calculate whether the value function is a good predictor of the returns (ev close to 1)
            # or worse than predicting nothing (ev <= 0)
            ev = utl_math.explained_variance(utl.sf01(rollouts.value_preds),
                                             utl.sf01(rollouts.returns))

            wandb.log(
                {
                    'misc/timesteps':
                    frames,
                    'misc/fps':
                    fps,
                    'misc/explained_variance':
                    float(ev),
                    'losses/total_loss':
                    total_loss,
                    'losses/value_loss':
                    value_loss,
                    'losses/action_loss':
                    action_loss,
                    'losses/dist_entropy':
                    dist_entropy,
                    'train/mean_episodic_return':
                    utl_math.safe_mean([
                        episode_info['r']
                        for episode_info in train_episode_info_buf
                    ]),
                    'train/mean_episodic_length':
                    utl_math.safe_mean([
                        episode_info['l']
                        for episode_info in train_episode_info_buf
                    ]),
                    'eval/mean_episodic_return':
                    utl_math.safe_mean([
                        episode_info['r']
                        for episode_info in eval_episode_info_buf
                    ]),
                    'eval/mean_episodic_length':
                    utl_math.safe_mean([
                        episode_info['l']
                        for episode_info in eval_episode_info_buf
                    ])
                },
                step=iter_idx)
            if args.use_bottleneck:
                wandb.log({'losses/vib_kl': vib_kl}, step=iter_idx)
            if args.num_val_envs > 0:
                wandb.log(
                    {
                        'losses/dist_matching_loss':
                        dist_matching_loss,
                        'val/mean_episodic_return':
                        utl_math.safe_mean([
                            episode_info['r']
                            for episode_info in val_episode_info_buf
                        ]),
                        'val/mean_episodic_length':
                        utl_math.safe_mean([
                            episode_info['l']
                            for episode_info in val_episode_info_buf
                        ])
                    },
                    step=iter_idx)
            if args.analyse_rep:
                wandb.log(
                    {
                        "analysis/" + key: val
                        for key, val in rep_measures.items()
                    },
                    step=iter_idx)

        # --- SAVE MODEL ---

        # save for every interval-th episode or for the last epoch
        if iter_idx != 0 and (iter_idx % args.save_interval == 0
                              or iter_idx == args.num_updates - 1):
            print("Saving Actor-Critic Model.")
            torch.save(actor_critic.state_dict(),
                       os.path.join(save_dir, "policy{0}.pt".format(iter_idx)))

    # close envs
    train_envs.close()
    eval_envs.close()

    # --- TEST ---

    if args.test:
        print("Testing beginning.")
        episodic_return, latents_z = utl_test.test(args=args,
                                                   actor_critic=actor_critic,
                                                   device=device)

        # save returns from train and test levels to analyse using interactive mode
        train_levels = torch.arange(
            args.train_start_level,
            args.train_start_level + args.train_num_levels)
        for i, level in enumerate(train_levels):
            wandb.log({
                'test/train_levels': level,
                'test/train_returns': episodic_return[0][i]
            })
        test_levels = torch.arange(
            args.test_start_level,
            args.test_start_level + args.test_num_levels)
        for i, level in enumerate(test_levels):
            wandb.log({
                'test/test_levels': level,
                'test/test_returns': episodic_return[1][i]
            })
        # log returns from test envs
        wandb.run.summary["train_mean_episodic_return"] = utl_math.safe_mean(
            episodic_return[0])
        wandb.run.summary["test_mean_episodic_return"] = utl_math.safe_mean(
            episodic_return[1])

        # plot latent representation
        if args.plot_pca:
            print("Plotting PCA of Latent Representation.")
            utl_rep.pca(args, latents_z)
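
The training, eval and validation logging above averages episode statistics with utl_math.safe_mean, which is not shown in this example. A minimal sketch of such a helper, assuming it only guards against empty episode-info buffers (e.g. early iterations in which no episode has finished yet):

import numpy as np

def safe_mean(xs):
    # assumed helper: average a list of episode returns or lengths,
    # returning NaN instead of a warning while the buffer is still empty
    return np.nan if len(xs) == 0 else float(np.mean(xs))
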
Example #8
def fit(
        model,
        env,
        timesteps_per_batch,  # what to train on
        max_kl,
        cg_iters,
        gamma,
        lam,  # advantage estimation
        entcoeff=0.0,
        cg_damping=1e-2,
        vf_stepsize=3e-4,
        vf_iters=3,
        max_timesteps=0,
        max_episodes=0,
        max_iters=0,  # time constraint
        callback=None):
    # Setup losses and stuff
    # ----------------------------------------
    # nworkers = MPI.COMM_WORLD.Get_size()
    # rank = MPI.COMM_WORLD.Get_rank()
    np.set_printoptions(precision=3)

    th_init = model.get_flat()
    MPI.COMM_WORLD.Bcast(th_init, root=0)
    model.set_from_flat(th_init)
    model.vfadam.sync()
    print("Init param sum", th_init.sum(), flush=True)

    # Prepare for rollouts
    # ----------------------------------------
    seg_gen = model.traj_segment_generator(model.pi,
                                           env,
                                           timesteps_per_batch,
                                           stochastic=True)

    episodes_so_far = 0
    timesteps_so_far = 0
    iters_so_far = 0
    tstart = time.time()
    lenbuffer = deque(maxlen=40)  # rolling buffer for episode lengths
    rewbuffer = deque(maxlen=40)  # rolling buffer for episode rewards

    # exactly one stopping criterion (max_timesteps, max_episodes or max_iters) must be set
    assert sum([max_iters > 0, max_timesteps > 0, max_episodes > 0]) == 1

    while True:
        if callback:
            callback(locals(), globals())
        if max_timesteps and timesteps_so_far >= max_timesteps:
            break
        elif max_episodes and episodes_so_far >= max_episodes:
            break
        elif max_iters and iters_so_far >= max_iters:
            break
        logger.log("********** Iteration %i ************" % iters_so_far)

        with model.timed("sampling"):
            seg = seg_gen.__next__()

        model.add_vtarg_and_adv(seg, gamma, lam)

        # ob, ac, atarg, ret, td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets))
        ob, ac, atarg, tdlamret = seg["ob"], seg["ac"], seg["adv"], seg["tdlamret"]
        vpredbefore = seg["vpred"]  # predicted value function before update
        # standardized advantage function estimate
        atarg = (atarg - atarg.mean()) / atarg.std()

        if hasattr(model.pi, "ret_rms"):
            model.pi.ret_rms.update(tdlamret)
        if hasattr(model.pi, "ob_rms"):
            model.pi.ob_rms.update(ob)  # update running mean/std for policy

        args = seg["ob"], seg["ac"], atarg
        fvpargs = [arr[::5] for arr in args]

        def fisher_vector_product(p):
            # damped Fisher-vector product, averaged across MPI workers
            return model.allmean(model.compute_fvp(p, *fvpargs)) + cg_damping * p

        # set old parameter values to new parameter values
        model.assign_old_eq_new()
        with model.timed("computegrad"):
            *lossbefore, g = model.compute_lossandgrad(*args)
        lossbefore = model.allmean(np.array(lossbefore))
        g = model.allmean(g)
        if np.allclose(g, 0):
            logger.log("Got zero gradient. not updating")
        else:
            with model.timed("cg"):
                stepdir = conjugate_gradient(fisher_vector_product,
                                             g,
                                             cg_iters=cg_iters,
                                             verbose=model.rank == 0)
            assert np.isfinite(stepdir).all()
            shs = .5 * stepdir.dot(fisher_vector_product(stepdir))
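            # scale the step so that the quadratic KL estimate 0.5 * x^T H x of the
            # full step equals max_kl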
            lm = np.sqrt(shs / max_kl)
            # logger.log("lagrange multiplier:", lm, "gnorm:", np.linalg.norm(g))
            fullstep = stepdir / lm
            expectedimprove = g.dot(fullstep)
            surrbefore = lossbefore[0]
            stepsize = 1.0
            thbefore = model.get_flat()
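            # backtracking line search: halve the step until the surrogate improves
            # and the KL divergence stays below 1.5 * max_kl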
            for _ in range(10):
                thnew = thbefore + fullstep * stepsize
                model.set_from_flat(thnew)
                meanlosses = surr, kl, *_ = model.allmean(
                    np.array(model.compute_losses(*args)))
                improve = surr - surrbefore
                logger.log("Expected: %.3f Actual: %.3f" %
                           (expectedimprove, improve))
                if not np.isfinite(meanlosses).all():
                    logger.log("Got non-finite value of losses -- bad!")
                elif kl > max_kl * 1.5:
                    logger.log("violated KL constraint. shrinking step.")
                elif improve < 0:
                    logger.log("surrogate didn't improve. shrinking step.")
                else:
                    logger.log("Stepsize OK!")
                    break
                stepsize *= .5
            else:
                logger.log("couldn't compute a good step")
                model.set_from_flat(thbefore)
            if model.nworkers > 1 and iters_so_far % 20 == 0:
                paramsums = MPI.COMM_WORLD.allgather(
                    (thnew.sum(),
                     model.vfadam.getflat().sum()))  # list of tuples
                assert all(
                    np.allclose(ps, paramsums[0]) for ps in paramsums[1:])

        for (lossname, lossval) in zip(model.loss_names, meanlosses):
            logger.record_tabular(lossname, lossval)

        with model.timed("vf"):

            for _ in range(vf_iters):
                for (mbob, mbret) in dataset.iterbatches(
                    (seg["ob"], seg["tdlamret"]),
                        include_final_partial_batch=False,
                        batch_size=64):
                    g = model.allmean(model.compute_vflossandgrad(mbob, mbret))
                    model.vfadam.update(g, vf_stepsize)

        logger.record_tabular("ev_tdlam_before",
                              explained_variance(vpredbefore, tdlamret))

        lrlocal = (seg["ep_lens"], seg["ep_rets"])  # local values
        listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal)  # list of tuples
        lens, rews = map(model.flatten_lists, zip(*listoflrpairs))
        lenbuffer.extend(lens)
        rewbuffer.extend(rews)

        logger.record_tabular("EpLenMean", np.mean(lenbuffer))
        logger.record_tabular("EpRewMean", np.mean(rewbuffer))
        logger.record_tabular("EpThisIter", len(lens))
        episodes_so_far += len(lens)
        timesteps_so_far += sum(lens)
        iters_so_far += 1

        logger.record_tabular("EpisodesSoFar", episodes_so_far)
        logger.record_tabular("TimestepsSoFar", timesteps_so_far)
        logger.record_tabular("TimeElapsed", time.time() - tstart)

        if model.rank == 0:
            logger.dump_tabular()
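
The TRPO-style update above relies on conjugate_gradient to solve the Fisher system H x = g using only the Fisher-vector products defined in the loop; the solver itself is not shown. A minimal sketch of a standard conjugate-gradient routine with a matching call signature (an assumed implementation, not necessarily the one used here):

import numpy as np

def conjugate_gradient(f_Ax, b, cg_iters=10, verbose=False, residual_tol=1e-10):
    # approximately solve A x = b using only the matrix-vector product f_Ax(p)
    x = np.zeros_like(b)
    r = b.copy()  # residual b - A x, with x initialised to zero
    p = b.copy()  # current search direction
    rdotr = r.dot(r)
    for i in range(cg_iters):
        if verbose:
            print("cg iter %i residual norm %g" % (i, rdotr))
        z = f_Ax(p)
        alpha = rdotr / p.dot(z)
        x += alpha * p
        r -= alpha * z
        newrdotr = r.dot(r)
        p = r + (newrdotr / rdotr) * p
        rdotr = newrdotr
        if rdotr < residual_tol:
            break
    return x
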
Example #9
def fit(policy,
        env,
        seed,
        total_timesteps=int(40e6),
        gamma=0.99,
        log_interval=1,
        nprocs=32,
        nsteps=20,
        ent_coef=0.01,
        vf_coef=0.5,
        vf_fisher_coef=1.0,
        lr=0.25,
        max_grad_norm=0.5,
        kfac_clip=0.001,
        save_interval=None,
        lrschedule='linear'):
    tf.reset_default_graph()
    set_global_seeds(seed)

    nenvs = env.num_envs
    ob_space = env.observation_space
    ac_space = env.action_space
    model = AcktrDiscrete(policy,
                          ob_space,
                          ac_space,
                          nenvs,
                          total_timesteps,
                          nsteps=nsteps,
                          ent_coef=ent_coef,
                          vf_coef=vf_coef,
                          vf_fisher_coef=vf_fisher_coef,
                          lr=lr,
                          max_grad_norm=max_grad_norm,
                          kfac_clip=kfac_clip,
                          lrschedule=lrschedule)
    # if save_interval and logger.get_dir():
    #     import cloudpickle
    #     with open(os.path.join(logger.get_dir(), 'make_model.pkl'), 'wb') as fh:
    #         fh.write(cloudpickle.dumps(make_model))
    # model = make_model()

    runner = Environment(env, model, nsteps=nsteps, gamma=gamma)
    nbatch = nenvs * nsteps
    tstart = time.time()
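    # start the optimizer's queue-runner threads (presumably K-FAC's asynchronous
    # curvature-statistics updates) under a TF coordinator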
    coord = tf.train.Coordinator()
    enqueue_threads = model.q_runner.create_threads(model.sess,
                                                    coord=coord,
                                                    start=True)
    for update in range(1, total_timesteps // nbatch + 1):
        obs, states, rewards, masks, actions, values = runner.run()
        policy_loss, value_loss, policy_entropy = model.train(
            obs, states, rewards, masks, actions, values)
        model.old_obs = obs
        nseconds = time.time() - tstart
        fps = int((update * nbatch) / nseconds)
        if update % log_interval == 0 or update == 1:
            ev = explained_variance(values, rewards)
            logger.record_tabular("nupdates", update)
            logger.record_tabular("total_timesteps", update * nbatch)
            logger.record_tabular("fps", fps)
            logger.record_tabular("policy_entropy", float(policy_entropy))
            logger.record_tabular("policy_loss", float(policy_loss))
            logger.record_tabular("value_loss", float(value_loss))
            logger.record_tabular("explained_variance", float(ev))
            logger.dump_tabular()

        if save_interval and (update % save_interval == 0 or update == 1) \
           and logger.get_dir():
            savepath = os.path.join(logger.get_dir(),
                                    'checkpoint%.5i' % update)
            print('Saving to', savepath)
            model.save(savepath)
    coord.request_stop()
    coord.join(enqueue_threads)
    env.close()
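
Both examples above report explained_variance as a value-function diagnostic. A minimal sketch of that metric under its usual definition, 1 - Var[y - ypred] / Var[y] (assumed here; the actual helper is not shown):

import numpy as np

def explained_variance(ypred, y):
    # 1.0 means the value function predicts the returns perfectly,
    # 0.0 means it does no better than predicting the mean, negative is worse
    vary = np.var(y)
    return np.nan if vary == 0 else 1.0 - np.var(y - ypred) / vary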