Example #1
    def _write_evaluation(self, d):
        logger.logkvs(d)
        logger.dumpkvs()
        try:
            with open(os.path.join(self.logdir, 'eval.pkl'), 'rb') as f:
                d_ = pkl.load(f)
        except FileNotFoundError:
            d_ = {}

        for k in d_.keys():
            if k not in d:
                d[k] = d_[k]

        with open(os.path.join(self.logdir, 'eval.pkl'), 'wb') as f:
            pkl.dump(d, f)
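
The helper above merges new evaluation results into `eval.pkl` so that keys written by earlier calls are preserved. A minimal standalone sketch of that merge-and-persist pattern using only the standard library (the `logger` calls are left out; `merge_eval` is a hypothetical name):

import os
import pickle as pkl

def merge_eval(logdir, d):
    """Merge new results into eval.pkl, keeping keys written by earlier calls."""
    path = os.path.join(logdir, 'eval.pkl')
    try:
        with open(path, 'rb') as f:
            previous = pkl.load(f)
    except FileNotFoundError:
        previous = {}
    previous.update(d)  # new values take precedence; keys only in the old file survive
    with open(path, 'wb') as f:
        pkl.dump(previous, f)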
Example #2
    def train(self):
        """
        Implement the Reptile algorithm for PPO reinforcement learning.
        """
        start_time = time.time()
        avg_ret = []
        avg_pg_loss = []
        avg_vf_loss = []

        avg_latencies = []
        for itr in range(self.start_itr, self.n_itr):
            itr_start_time = time.time()
            logger.log("\n ---------------- Iteration %d ----------------" %
                       itr)
            logger.log("Sampling set of tasks/goals for this meta-batch...")

            paths = self.sampler.obtain_samples(log=True, log_prefix='')
            """ ----------------- Processing Samples ---------------------"""
            logger.log("Processing samples...")
            samples_data = self.sampler_processor.process_samples(
                paths, log='all', log_prefix='')
            """ ------------------- Inner Policy Update --------------------"""
            policy_losses, value_losses = self.algo.UpdatePPOTarget(
                samples_data, batch_size=self.batch_size)

            #print("task losses: ", losses)
            print("average policy losses: ", np.mean(policy_losses))
            avg_pg_loss.append(np.mean(policy_losses))

            print("average value losses: ", np.mean(value_losses))
            avg_vf_loss.append(np.mean(value_losses))
            """ ------------------- Logging Stuff --------------------------"""

            ret = np.sum(samples_data['rewards'], axis=-1)
            avg_reward = np.mean(ret)

            latency = samples_data['finish_time']
            avg_latency = np.mean(latency)

            avg_latencies.append(avg_latency)

            logger.logkv('Itr', itr)
            logger.logkv('Average reward', avg_reward)
            logger.logkv('Average latency', avg_latency)
            logger.dumpkvs()
            avg_ret.append(avg_reward)

        return avg_ret, avg_pg_loss, avg_vf_loss, avg_latencies
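
The logging block above reduces a batch of sampled trajectories to two scalars per iteration. A toy sketch of that reduction with dummy data, assuming `samples_data['rewards']` has shape `(num_trajectories, horizon)` and `finish_time` holds one latency per trajectory:

import numpy as np

# Dummy batch: 4 trajectories with a horizon of 5 steps each.
samples_data = {
    'rewards': np.random.rand(4, 5),
    'finish_time': np.random.rand(4) * 10,
}

ret = np.sum(samples_data['rewards'], axis=-1)   # per-trajectory return
avg_reward = np.mean(ret)                        # batch-average return
avg_latency = np.mean(samples_data['finish_time'])
print(avg_reward, avg_latency)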
Example #3
def main(args):
    continuous_actions = (args.env_name in [
        'AntVel-v1', 'AntDir-v1', 'AntPos-v0', 'HalfCheetahVel-v1',
        'HalfCheetahDir-v1', '2DNavigation-v0', 'Point2DWalls-corner-v0',
        'Ant-v0', 'HalfCheetah-v0'
    ])

    logger.configure(dir=args.log_dir, format_strs=['stdout', 'log', 'csv'])
    logger.log(args)
    json.dump(vars(args),
              open(os.path.join(
                  args.log_dir,
                  'params.json',
              ), 'w'),
              indent=2)

    sampler = BatchSamplerMultiworld(args)
    sampler_val = BatchSamplerMultiworld(args, val=True)

    if continuous_actions:
        policy = NormalMLPPolicy(
            int(np.prod(sampler.envs.observation_space.shape)),
            int(np.prod(sampler.envs.action_space.shape)),
            hidden_sizes=(args.hidden_size, ) * args.num_layers,
            bias_transformation_size=args.bias_transformation_size,
            init_gain=args.init_gain,
        )
    else:
        raise NotImplementedError
    baseline = LinearFeatureBaseline(
        int(np.prod(sampler.envs.observation_space.shape)))

    metalearner = MetaLearner(sampler,
                              policy,
                              baseline,
                              gamma=args.gamma,
                              fast_lr=args.fast_lr,
                              tau=args.tau,
                              entropy_coef=args.entropy_coef,
                              device=args.device)

    start_time = time.time()

    processes = []

    for batch in range(args.num_batches):
        metalearner.reset()
        tasks = sampler.sample_tasks(num_tasks=args.meta_batch_size)
        episodes = metalearner.sample(tasks, first_order=args.first_order)
        if sampler.rewarder.fit_counter > 0:
            metalearner.step(episodes,
                             max_kl=args.max_kl,
                             cg_iters=args.cg_iters,
                             cg_damping=args.cg_damping,
                             ls_max_steps=args.ls_max_steps,
                             ls_backtrack_ratio=args.ls_backtrack_ratio)

        if batch % args.rewarder_fit_period == 0:
            sampler.fit_rewarder(logger)

        if args.rewarder == 'unsupervised':
            sampler.log_unsupervised(logger)
        log_main(logger, episodes, batch, args, start_time, metalearner)

        if batch % args.save_period == 0 or batch == args.num_batches - 1:
            save_model_maml(args, policy, batch)

        if batch % args.val_period == 0 or batch == args.num_batches - 1:
            val(args, sampler_val, policy, baseline, batch)

        if batch % args.vis_period == 0 or batch == args.num_batches - 1:
            if args.plot:
                p = Popen(
                    'python maml_rl/utils/visualize.py --log-dir {}'.format(
                        args.log_dir),
                    shell=True)
                processes.append(p)

        logger.dumpkvs()
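
Saving, validation, and visualization above all use the same trigger: fire every `period` batches and also on the final batch. A tiny sketch of that pattern with hypothetical names:

def should_fire(step, period, num_steps):
    """True every `period` steps and on the last step."""
    return step % period == 0 or step == num_steps - 1

# With period=3 over 8 steps this fires at 0, 3, 6 and 7.
print([s for s in range(8) if should_fire(s, 3, 8)])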
Example #4
def train():
    processes = []
    if os.path.isdir(args.log_dir):
        ans = input('{} exists\ncontinue and overwrite? y/n: '.format(
            args.log_dir))
        if ans == 'n':
            return

    logger.configure(dir=args.log_dir, format_strs=['stdout', 'log', 'csv'])
    logger.log(args)
    json.dump(vars(args), open(os.path.join(args.log_dir, 'params.json'), 'w'))

    torch.set_num_threads(2)

    start = time.time()
    policy_update_time, policy_forward_time = 0, 0
    step_time_env, step_time_total, step_time_rewarder = 0, 0, 0
    visualize_time = 0
    rewarder_fit_time = 0

    envs = ContextualEnvInterface(args)
    if args.look:
        looker = Looker(args.log_dir)

    actor_critic, agent = initialize_policy(envs)

    rollouts = RolloutStorage(args.num_steps, args.num_processes,
                              envs.obs_shape, envs.action_space,
                              actor_critic.recurrent_hidden_state_size)
    rollouts.to(args.device)

    def copy_obs_into_beginning_of_storage(obs):
        rollouts.obs[0].copy_(obs)

    for j in range(args.num_updates):

        obs = envs.reset()  # have to reset here to use updated rewarder to sample tasks
        copy_obs_into_beginning_of_storage(obs)

        if args.use_linear_lr_decay:
            update_linear_schedule(agent.optimizer, j, args.num_updates,
                                   args.lr)

        if args.algo == 'ppo' and args.use_linear_clip_decay:
            agent.clip_param = args.clip_param * (1 -
                                                  j / float(args.num_updates))

        log_marginal = 0
        lambda_log_s_given_z = 0

        for step in range(args.num_steps):
            # Sample actions
            policy_forward_start = time.time()
            with torch.no_grad():
                value, action, action_log_prob, recurrent_hidden_states = actor_critic.act(
                    rollouts.obs[step], rollouts.recurrent_hidden_states[step],
                    rollouts.masks[step])
            policy_forward_time += time.time() - policy_forward_start

            # Observe reward and next obs
            step_total_start = time.time()
            obs, reward, done, info = envs.step(action)
            step_time_total += time.time() - step_total_start
            step_time_env += info['step_time_env']
            step_time_rewarder += info['reward_time']
            if args.rewarder == 'unsupervised' and args.clusterer == 'vae':
                log_marginal += info['log_marginal'].sum().item()
                lambda_log_s_given_z += info['lambda_log_s_given_z'].sum().item()

            # If done then clean the history of observations.
            masks = torch.FloatTensor([[0.0] if done_ else [1.0]
                                       for done_ in done])
            rollouts.insert(obs, recurrent_hidden_states, action,
                            action_log_prob, value, reward, masks)

        assert all(done)

        # policy update
        with torch.no_grad():
            next_value = actor_critic.get_value(
                rollouts.obs[-1], rollouts.recurrent_hidden_states[-1],
                rollouts.masks[-1]).detach()

        rollouts.compute_returns(next_value, args.use_gae, args.gamma,
                                 args.tau)

        policy_update_start = time.time()
        if args.rewarder != 'supervised' and envs.rewarder.fit_counter == 0:
            value_loss, action_loss, dist_entropy = 0, 0, 0
        else:
            value_loss, action_loss, dist_entropy = agent.update(rollouts)
        policy_update_time += time.time() - policy_update_start
        rollouts.after_update()

        # metrics
        trajectories = envs.trajectories_current_update
        state_entropy = calculate_state_entropy(args, trajectories)

        return_avg = rollouts.rewards.sum() / args.trials_per_update
        reward_avg = return_avg / (args.trial_length * args.episode_length)
        log_marginal_avg = log_marginal / args.trials_per_update / (
            args.trial_length * args.episode_length)
        lambda_log_s_given_z_avg = lambda_log_s_given_z / args.trials_per_update / (
            args.trial_length * args.episode_length)

        num_steps = (j + 1) * args.num_steps * args.num_processes
        num_episodes = num_steps // args.episode_length
        num_trials = num_episodes // args.trial_length

        logger.logkv('state_entropy', state_entropy)
        logger.logkv('value_loss', value_loss)
        logger.logkv('action_loss', action_loss)
        logger.logkv('dist_entropy', dist_entropy)
        logger.logkv('return_avg', return_avg.item())
        logger.logkv('reward_avg', reward_avg.item())
        logger.logkv('steps', num_steps)
        logger.logkv('episodes', num_episodes)
        logger.logkv('trials', num_trials)
        logger.logkv('policy_updates', (j + 1))
        logger.logkv('time', time.time() - start)
        logger.logkv('policy_forward_time', policy_forward_time)
        logger.logkv('policy_update_time', policy_update_time)
        logger.logkv('step_time_rewarder', step_time_rewarder)
        logger.logkv('step_time_env', step_time_env)
        logger.logkv('step_time_total', step_time_total)
        logger.logkv('visualize_time', visualize_time)
        logger.logkv('rewarder_fit_time', rewarder_fit_time)
        if args.rewarder == 'unsupervised' and args.clusterer == 'vae':
            logger.logkv('log_marginal_avg', log_marginal_avg)
            logger.logkv('lambda_log_s_given_z_avg', lambda_log_s_given_z_avg)
        logger.dumpkvs()

        if (j % args.save_period == 0
                or j == args.num_updates - 1) and args.log_dir != '':
            save_model(args, actor_critic, envs, iteration=j)

        if j % args.rewarder_fit_period == 0:
            rewarder_fit_start = time.time()
            envs.fit_rewarder()
            rewarder_fit_time += time.time() - rewarder_fit_start

        if (j % args.vis_period == 0
                or j == args.num_updates - 1) and args.log_dir != '':
            visualize_start = time.time()
            if args.look:
                looker.look(iteration=j)
            if args.plot:
                p = Popen('python visualize.py --log-dir {}'.format(
                    args.log_dir),
                          shell=True)
                processes.append(p)
            visualize_time += time.time() - visualize_start
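
One detail worth noting in the rollout loop above: when a process reports `done`, its mask becomes 0.0 so that returns (and any recurrent state) are cut at the episode boundary. A minimal PyTorch sketch of the mask construction with a dummy `done` vector:

import torch

done = [False, True, False]  # per-process done flags from envs.step(action)
masks = torch.FloatTensor([[0.0] if d else [1.0] for d in done])
print(masks)  # tensor([[1.], [0.], [1.]])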
Example #5
def train():
    processes = []
    if os.path.isdir(args.log_dir):
        ans = input('{} exists\ncontinue and overwrite? y/n: '.format(args.log_dir))
        if ans == 'n':
            return

    logger.configure(dir=args.log_dir, format_strs=['stdout', 'log', 'csv'])
    logger.log(args)
    json.dump(vars(args), open(os.path.join(args.log_dir, 'params.json'), 'w'))

    torch.set_num_threads(2)

    start = time.time()
    policy_update_time, policy_forward_time = 0, 0
    step_time_env, step_time_total, step_time_rewarder = 0, 0, 0
    visualize_time = 0
    rewarder_fit_time = 0

    envs = RL2EnvInterface(args)
    if args.look:
        looker = Looker(args.log_dir)

    actor_critic = Policy(envs.obs_shape, envs.action_space,
                          base=RL2Base, base_kwargs={'recurrent': True,
                                                     'num_act_dim': envs.action_space.shape[0]})
    actor_critic.to(args.device)
    agent = algo.PPO(actor_critic, args.clip_param, args.ppo_epoch, args.num_mini_batch,
                     args.value_loss_coef, args.entropy_coef, lr=args.lr,
                     eps=args.eps,
                     max_grad_norm=args.max_grad_norm)

    rollouts = RolloutStorage(args.num_steps, args.num_processes,
                        envs.obs_shape, envs.action_space,
                        actor_critic.recurrent_hidden_state_size)
    rollouts.to(args.device)

    def copy_obs_into_beginning_of_storage(obs):
        obs_raw, obs_act, obs_rew, obs_flag = obs
        rollouts.obs[0].copy_(obs_raw)
        rollouts.obs_act[0].copy_(obs_act)
        rollouts.obs_rew[0].copy_(obs_rew)
        rollouts.obs_flag[0].copy_(obs_flag)

    for j in range(args.num_updates):
        obs = envs.reset()
        copy_obs_into_beginning_of_storage(obs)

        if args.use_linear_lr_decay:
            update_linear_schedule(agent.optimizer, j, args.num_updates, args.lr)

        if args.algo == 'ppo' and args.use_linear_clip_decay:
            agent.clip_param = args.clip_param * (1 - j / float(args.num_updates))

        episode_returns = [0 for i in range(args.trial_length)]
        episode_final_reward = [0 for i in range(args.trial_length)]
        i_episode = 0

        log_marginal = 0
        lambda_log_s_given_z = 0

        for step in range(args.num_steps):
            # Sample actions
            policy_forward_start = time.time()
            with torch.no_grad():
                value, action, action_log_prob, recurrent_hidden_states = actor_critic.act(
                        rollouts.get_obs(step),
                        rollouts.recurrent_hidden_states[step],
                        rollouts.masks[step])
            policy_forward_time += time.time() - policy_forward_start

            # Observe reward and next obs
            step_total_start = time.time()
            obs, reward, done, info = envs.step(action)
            step_time_total += time.time() - step_total_start
            step_time_env += info['step_time_env']
            step_time_rewarder += info['reward_time']
            log_marginal += info['log_marginal'].sum().item()
            lambda_log_s_given_z += info['lambda_log_s_given_z'].sum().item()

            episode_returns[i_episode] += reward.sum().item()
            if all(done['episode']):
                episode_final_reward[i_episode] += reward.sum().item()
                i_episode = (i_episode + 1) % args.trial_length

            # If done then clean the history of observations.
            masks = torch.FloatTensor([[0.0] if done_ else [1.0] for done_ in done['trial']])
            rollouts.insert(obs, recurrent_hidden_states, action, action_log_prob, value, reward, masks)

        assert all(done['trial'])

        with torch.no_grad():
            next_value = actor_critic.get_value(rollouts.get_obs(-1),
                                                rollouts.recurrent_hidden_states[-1],
                                                rollouts.masks[-1]).detach()

        rollouts.compute_returns(next_value, args.use_gae, args.gamma, args.tau)

        policy_update_start = time.time()
        if args.rewarder != 'supervised' and envs.rewarder.fit_counter == 0 and not args.vae_load:
            value_loss, action_loss, dist_entropy = 0, 0, 0
        else:
            value_loss, action_loss, dist_entropy = agent.update(rollouts)
        policy_update_time += time.time() - policy_update_start
        rollouts.after_update()

        # metrics
        trajectories_pre = envs.trajectories_pre_current_update
        state_entropy_pre = calculate_state_entropy(args, trajectories_pre)

        trajectories_post = envs.trajectories_post_current_update
        state_entropy_post = calculate_state_entropy(args, trajectories_post)

        return_avg = rollouts.rewards.sum() / args.trials_per_update
        reward_avg = return_avg / (args.trial_length * args.episode_length)
        log_marginal_avg = log_marginal / args.trials_per_update / (args.trial_length * args.episode_length)
        lambda_log_s_given_z_avg = lambda_log_s_given_z / args.trials_per_update / (args.trial_length * args.episode_length)

        num_steps = (j + 1) * args.num_steps * args.num_processes
        num_episodes = num_steps // args.episode_length
        num_trials = num_episodes // args.trial_length

        logger.logkv('state_entropy_pre', state_entropy_pre)
        logger.logkv('state_entropy_post', state_entropy_post)
        logger.logkv('value_loss', value_loss)
        logger.logkv('action_loss', action_loss)
        logger.logkv('dist_entropy', dist_entropy)
        logger.logkv('return_avg', return_avg.item())
        logger.logkv('reward_avg', reward_avg.item())
        logger.logkv('steps', (j + 1) * args.num_steps * args.num_processes)
        logger.logkv('episodes', num_episodes)
        logger.logkv('trials', num_trials)
        logger.logkv('policy_updates', (j + 1))
        logger.logkv('time', time.time() - start)
        logger.logkv('policy_forward_time', policy_forward_time)
        logger.logkv('policy_update_time', policy_update_time)
        logger.logkv('step_time_rewarder', step_time_rewarder)
        logger.logkv('step_time_env', step_time_env)
        logger.logkv('step_time_total', step_time_total)
        logger.logkv('visualize_time', visualize_time)
        logger.logkv('rewarder_fit_time', rewarder_fit_time)
        logger.logkv('log_marginal_avg', log_marginal_avg)
        logger.logkv('lambda_log_s_given_z_avg', lambda_log_s_given_z_avg)
        for i_episode in range(args.trial_length):
            logger.logkv('episode_return_avg_{}'.format(i_episode),
                         episode_returns[i_episode] / args.trials_per_update)
            logger.logkv('episode_final_reward_{}'.format(i_episode),
                         episode_final_reward[i_episode] / args.trials_per_update)

        if (j % args.save_period == 0 or j == args.num_updates - 1) and args.log_dir != '':
            save_model(args, actor_critic, envs, iteration=j)

        if not args.vae_freeze and j % args.rewarder_fit_period == 0:
            rewarder_fit_start = time.time()
            envs.fit_rewarder()
            rewarder_fit_time += time.time() - rewarder_fit_start

        if (j % args.vis_period == 0 or j == args.num_updates - 1) and args.log_dir != '':
            visualize_start = time.time()
            if args.look:
                eval_return_avg, eval_episode_returns, eval_episode_final_reward = looker.look(iteration=j)
                logger.logkv('eval_return_avg', eval_return_avg)
                for i_episode in range(args.trial_length):
                    logger.logkv('eval_episode_return_avg_{}'.format(i_episode),
                                 eval_episode_returns[i_episode] / args.trials_per_update)
                    logger.logkv('eval_episode_final_reward_{}'.format(i_episode),
                                 eval_episode_final_reward[i_episode] / args.trials_per_update)

            if args.plot:
                p = Popen('python visualize.py --log-dir {}'.format(args.log_dir), shell=True)
                processes.append(p)
            visualize_time += time.time() - visualize_start

        logger.dumpkvs()
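
The per-episode bookkeeping above accumulates reward into a slot indexed by the episode's position within the trial and advances that index whenever an episode ends. A toy version of the rotation with a hypothetical reward stream:

trial_length = 3
episode_returns = [0.0] * trial_length
i_episode = 0

# (reward, episode_done) pairs standing in for envs.step(...) results.
for reward, episode_done in [(1.0, False), (0.5, True), (2.0, True), (1.5, True)]:
    episode_returns[i_episode] += reward
    if episode_done:
        i_episode = (i_episode + 1) % trial_length

print(episode_returns)  # [1.5, 2.0, 1.5]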
Example #6
def fit(
        policy,
        env,
        nsteps,
        total_timesteps,
        ent_coef,
        lr,
        vf_coef=0.5,
        max_grad_norm=0.5,
        gamma=0.99,
        lam=0.95,
        log_interval=10,
        nminibatches=4,
        noptepochs=4,
        cliprange=0.2,
        save_interval=0,
        load_path=None
):

    if isinstance(lr, float):
        lr = constfn(lr)
    else:
        assert callable(lr)
    if isinstance(cliprange, float):
        cliprange = constfn(cliprange)
    else:
        assert callable(cliprange)
    total_timesteps = int(total_timesteps)

    nenvs = env.num_envs
    # nenvs = 8
    ob_space = env.observation_space
    ac_space = env.action_space
    nbatch = nenvs * nsteps
    nbatch_train = nbatch // nminibatches

    model = PPO2(
        policy=policy,
        observation_space=ob_space,
        action_space=ac_space,
        nbatch_act=nenvs,
        nbatch_train=nbatch_train,
        nsteps=nsteps,
        ent_coef=ent_coef,
        vf_coef=vf_coef,
        max_grad_norm=max_grad_norm
    )
    Agent().init_vars()
    # if save_interval and logger.get_dir():
    #     import cloudpickle
    #     with open(os.path.join(logger.get_dir(), 'make_model.pkl'), 'wb') as fh:
    #         fh.write(cloudpickle.dumps(make_model))
    # model = make_model()
    # if load_path is not None:
    #     model.load(load_path)
    runner = Environment(env=env, model=model, nsteps=nsteps, gamma=gamma, lam=lam)

    epinfobuf = deque(maxlen=100)
    tfirststart = time.time()

    nupdates = total_timesteps//nbatch
    for update in range(1, nupdates+1):
        assert nbatch % nminibatches == 0
        nbatch_train = nbatch // nminibatches
        tstart = time.time()
        frac = 1.0 - (update - 1.0) / nupdates
        lrnow = lr(frac)
        cliprangenow = cliprange(frac)
        obs, returns, masks, actions, values, neglogpacs, states, epinfos = runner.run()
        epinfobuf.extend(epinfos)
        mblossvals = []
        if states is None:  # nonrecurrent version
            inds = np.arange(nbatch)
            for _ in range(noptepochs):
                np.random.shuffle(inds)
                for start in range(0, nbatch, nbatch_train):
                    end = start + nbatch_train
                    mbinds = inds[start:end]
                    slices = (
                        arr[mbinds] for arr in (obs,
                                                returns,
                                                masks,
                                                actions,
                                                values,
                                                neglogpacs)
                    )
                    mblossvals.append(model.predict(lrnow, cliprangenow, *slices))
        else:  # recurrent version
            assert nenvs % nminibatches == 0
            envsperbatch = nenvs // nminibatches
            envinds = np.arange(nenvs)
            flatinds = np.arange(nenvs * nsteps).reshape(nenvs, nsteps)
            envsperbatch = nbatch_train // nsteps
            for _ in range(noptepochs):
                np.random.shuffle(envinds)
                for start in range(0, nenvs, envsperbatch):
                    end = start + envsperbatch
                    mbenvinds = envinds[start:end]
                    mbflatinds = flatinds[mbenvinds].ravel()
                    slices = (
                        arr[mbflatinds] for arr in (obs,
                                                    returns,
                                                    masks,
                                                    actions,
                                                    values,
                                                    neglogpacs)
                    )
                    mbstates = states[mbenvinds]
                    mblossvals.append(model.predict(lrnow, cliprangenow, *slices, mbstates))

        lossvals = np.mean(mblossvals, axis=0)
        tnow = time.time()
        fps = int(nbatch / (tnow - tstart))
        if update % log_interval == 0 or update == 1:
            ev = explained_variance(values, returns)
            logger.logkv("serial_timesteps", update*nsteps)
            logger.logkv("nupdates", update)
            logger.logkv("total_timesteps", update*nbatch)
            logger.logkv("fps", fps)
            logger.logkv("explained_variance", float(ev))
            logger.logkv('eprewmean', safemean([epinfo['r'] for epinfo in epinfobuf]))
            logger.logkv('eplenmean', safemean([epinfo['l'] for epinfo in epinfobuf]))
            logger.logkv('time_elapsed', tnow - tfirststart)
            for (lossval, lossname) in zip(lossvals, model.loss_names):
                logger.logkv(lossname, lossval)
            logger.dumpkvs()
        if save_interval and (update % save_interval == 0 or update == 1) and logger.get_dir():
            checkdir = os.path.join(logger.get_dir(), 'checkpoints')
            os.makedirs(checkdir, exist_ok=True)
            savepath = os.path.join(checkdir, '%.5i' % update)
            print('Saving to', savepath)
            model.save(savepath)
    env.close()
    return model
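
In the non-recurrent branch above, the same shuffled index array slices every rollout tensor, so each minibatch stays aligned across observations, returns, actions and the rest. A compact sketch of that index-shuffling pattern on dummy arrays:

import numpy as np

nbatch, nbatch_train, noptepochs = 8, 4, 2
obs = np.arange(nbatch) * 10       # stand-ins for the flattened rollout arrays
returns = np.arange(nbatch) * 100

inds = np.arange(nbatch)
for _ in range(noptepochs):
    np.random.shuffle(inds)
    for start in range(0, nbatch, nbatch_train):
        mbinds = inds[start:start + nbatch_train]
        mb_obs, mb_returns = obs[mbinds], returns[mbinds]
        # mb_obs[i] and mb_returns[i] still describe the same timestep.
        print(mbinds, mb_obs, mb_returns)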
Example #7
    def train(self):
        epinfobuf = deque(maxlen=10)
        t_trainstart = time.time()
        for iter in range(self.global_iter, self.max_iters):
            self.global_iter = iter
            t_iterstart = time.time()
            if iter % self.config['save_interval'] == 0 and logger.get_dir():
                with torch.no_grad():
                    res = self.rollout(val=True)
                    obs, actions, returns, values, advs, log_probs, epinfos = res
                avg_reward = safemean([epinfo['reward'] for epinfo in epinfos])
                avg_area = safemean(
                    [epinfo['seen_area'] for epinfo in epinfos])
                logger.logkv("iter", iter)
                logger.logkv("test/total_timesteps", iter * self.nbatch)
                logger.logkv('test/avg_area', avg_area)
                logger.logkv('test/avg_reward', avg_reward)
                logger.dumpkvs()
                if avg_reward > self.best_avg_reward:
                    self.best_avg_reward = avg_reward
                    is_best = True
                else:
                    is_best = False
                self.save_model(is_best=is_best, step=iter)
            with torch.no_grad():
                obs, actions, returns, values, advs, log_probs, epinfos = self.n_rollout(
                    repeat_num=self.config['train_rollout_repeat'])
            if epinfos:
                epinfobuf.extend(epinfos)
            lossvals = {
                'policy_loss': [],
                'value_loss': [],
                'policy_entropy': [],
                'approxkl': [],
                'clipfrac': []
            }
            opt_start_t = time.time()
            noptepochs = self.noptepochs
            for _ in range(noptepochs):
                num_batches = int(
                    np.ceil(actions.shape[1] / self.config['batch_size']))
                for x in range(num_batches):
                    b_start = x * self.config['batch_size']
                    b_end = min(b_start + self.config['batch_size'],
                                actions.shape[1])
                    if self.config['use_rgb_with_map']:
                        rgbs, large_maps, small_maps = obs
                        b_rgbs, b_large_maps, b_small_maps = map(
                            lambda p: p[:, b_start:b_end],
                            (rgbs, large_maps, small_maps))
                    else:
                        large_maps, small_maps = obs
                        b_large_maps, b_small_maps = map(
                            lambda p: p[:, b_start:b_end],
                            (large_maps, small_maps))
                    b_actions, b_returns, b_advs, b_log_probs = map(
                        lambda p: p[:, b_start:b_end],
                        (actions, returns, advs, log_probs))
                    hidden_state = self.net_model.init_hidden(
                        batch_size=b_end - b_start)
                    for start in range(0, actions.shape[0], self.rnn_seq_len):
                        end = start + self.rnn_seq_len
                        slices = (arr[start:end]
                                  for arr in (b_large_maps, b_small_maps,
                                              b_actions, b_returns, b_advs,
                                              b_log_probs))

                        if self.config['use_rgb_with_map']:
                            info, hidden_state = self.net_train(
                                *slices,
                                hidden_state=hidden_state,
                                rgbs=b_rgbs[start:end])
                        else:
                            info, hidden_state = self.net_train(
                                *slices, hidden_state=hidden_state)
                        lossvals['policy_loss'].append(info['pg_loss'])
                        lossvals['value_loss'].append(info['vf_loss'])
                        lossvals['policy_entropy'].append(info['entropy'])
                        lossvals['approxkl'].append(info['approxkl'])
                        lossvals['clipfrac'].append(info['clipfrac'])
            tnow = time.time()
            int_t_per_epo = (tnow - opt_start_t) / float(self.noptepochs)
            print_cyan(
                'Net training time per epoch: {0:.4f}s'.format(int_t_per_epo))
            fps = int(self.nbatch / (tnow - t_iterstart))
            if iter % self.config['log_interval'] == 0:
                logger.logkv("Learning rate",
                             self.optimizer.param_groups[0]['lr'])
                logger.logkv("per_env_timesteps", iter * self.num_steps)
                logger.logkv("iter", iter)
                logger.logkv("total_timesteps", iter * self.nbatch)
                logger.logkv("fps", fps)
                logger.logkv(
                    'ep_rew_mean',
                    safemean([epinfo['reward'] for epinfo in epinfobuf]))
                logger.logkv(
                    'ep_area_mean',
                    safemean([epinfo['seen_area'] for epinfo in epinfobuf]))
                logger.logkv('time_elapsed', tnow - t_trainstart)
                for name, value in lossvals.items():
                    logger.logkv(name, np.mean(value))
                logger.dumpkvs()
Example #8
    def _logger(self, keys, strs):
        values = self.sess.run(keys)
        for s, v in zip(strs, values):
            logger.logkv(s, v)
        logger.dumpkvs()
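
This helper relies on the buffering contract of the baselines-style `logger` used throughout these examples: `logkv` stages a key/value pair and `dumpkvs` flushes one row to every configured output. A short usage sketch, assuming the OpenAI baselines `logger` module is importable:

from baselines import logger

logger.configure(dir='/tmp/logger_demo', format_strs=['stdout', 'csv'])
for step in range(3):
    logger.logkv('step', step)
    logger.logkv('loss', 1.0 / (step + 1))
    logger.dumpkvs()  # writes one row to stdout and the CSV per call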
Example #9
def main(args):
    continuous_actions = (args.env_name in [
        'AntVel-v1', 'AntDir-v1', 'AntPos-v0', 'HalfCheetahVel-v1',
        'HalfCheetahDir-v1', '2DNavigation-v0', 'Point2DWalls-corner-v0',
        'Ant-v0', 'HalfCheetah-v0'
    ])

    writer = SummaryWriter(log_dir=args.log_dir)

    logger.configure(dir=args.log_dir, format_strs=['stdout', 'log', 'csv'])
    logger.log(args)
    json.dump(vars(args),
              open(os.path.join(
                  args.log_dir,
                  'params.json',
              ), 'w'),
              indent=2)

    sampler = BatchSampler(args.env_name,
                           batch_size=args.fast_batch_size,
                           num_workers=args.num_workers)
    if continuous_actions:
        policy = NormalMLPPolicy(
            int(np.prod(sampler.envs.observation_space.shape)),
            int(np.prod(sampler.envs.action_space.shape)),
            hidden_sizes=(args.hidden_size, ) * args.num_layers)
    else:
        policy = CategoricalMLPPolicy(
            int(np.prod(sampler.envs.observation_space.shape)),
            sampler.envs.action_space.n,
            hidden_sizes=(args.hidden_size, ) * args.num_layers)
    baseline = LinearFeatureBaseline(
        int(np.prod(sampler.envs.observation_space.shape)))

    metalearner = MetaLearner(sampler,
                              policy,
                              baseline,
                              gamma=args.gamma,
                              fast_lr=args.fast_lr,
                              tau=args.tau,
                              device=args.device)

    for batch in range(args.num_batches):
        tasks = sampler.sample_tasks(num_tasks=args.meta_batch_size)
        episodes = metalearner.sample(tasks, first_order=args.first_order)
        metalearner.step(episodes,
                         max_kl=args.max_kl,
                         cg_iters=args.cg_iters,
                         cg_damping=args.cg_damping,
                         ls_max_steps=args.ls_max_steps,
                         ls_backtrack_ratio=args.ls_backtrack_ratio)

        # Tensorboard
        writer.add_scalar('total_rewards/before_update',
                          total_rewards([ep.rewards for ep, _ in episodes]),
                          batch)
        writer.add_scalar('total_rewards/after_update',
                          total_rewards([ep.rewards for _, ep in episodes]),
                          batch)

        logger.logkv('return_avg_pre',
                     total_rewards([ep.rewards for ep, _ in episodes]))
        logger.logkv('return_avg_post',
                     total_rewards([ep.rewards for _, ep in episodes]))
        logger.dumpkvs()