Example #1
def test_mpi_weighted_mean():
    from mpi4py import MPI
    from baselines import logger
    from baselines.common import mpi_util
    comm = MPI.COMM_WORLD
    with logger.scoped_configure(comm=comm):
        if comm.rank == 0:
            name2valcount = {'a' : (10, 2), 'b' : (20,3)}
        elif comm.rank == 1:
            name2valcount = {'a' : (19, 1), 'c' : (42,3)}
        else:
            raise NotImplementedError

        d = mpi_util.mpi_weighted_mean(comm, name2valcount)
        correctval = {'a' : (10 * 2 + 19) / 3.0, 'b' : 20, 'c' : 42}
        if comm.rank == 0:
            assert d == correctval, '{} != {}'.format(d, correctval)

        for name, (val, count) in name2valcount.items():
            for _ in range(count):
                logger.logkv_mean(name, val)
        d2 = logger.dumpkvs()
        if comm.rank == 0:
            assert d2 == correctval
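The weighted mean the test asserts can be reproduced without MPI. The helper below is an illustrative sketch (not part of baselines): it merges per-rank {name: (value, count)} dicts the way the test expects mpi_util.mpi_weighted_mean to.

def weighted_mean_reference(rank_dicts):
    # Accumulate the weighted sum and total count per key, then divide.
    totals = {}
    for name2valcount in rank_dicts:
        for name, (val, count) in name2valcount.items():
            s, c = totals.get(name, (0.0, 0))
            totals[name] = (s + val * count, c + count)
    return {name: s / c for name, (s, c) in totals.items()}

# Reproduces the expected result asserted above:
# weighted_mean_reference([{'a': (10, 2), 'b': (20, 3)},
#                          {'a': (19, 1), 'c': (42, 3)}])
# -> {'a': (10 * 2 + 19) / 3.0, 'b': 20.0, 'c': 42.0}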
Example #2
def learn(*, network, env, total_timesteps, eval_env = None, seed=None, nsteps=2048, ent_coef=0.0, lr=3e-4,
            vf_coef=0.5,  max_grad_norm=0.5, gamma=0.99, lam=0.95,
            log_interval=10, nminibatches=4, noptepochs=4, cliprange=0.2,
            save_interval=0, load_path=None, model_fn=None, **network_kwargs):
    '''
    Learn policy using PPO algorithm (https://arxiv.org/abs/1707.06347)

    Parameters:
    ----------

    network:                          policy network architecture. Either a string (mlp, lstm, lnlstm, cnn_lstm, cnn, cnn_small, conv_only - see baselines.common/models.py for the full list)
                                      naming a standard network architecture, or a function that takes a tensorflow tensor as input and returns a
                                      tuple (output_tensor, extra_feed), where output_tensor is the output of the last network layer and extra_feed is None for feed-forward
                                      neural nets or a dictionary describing how to feed state into the network for recurrent neural nets.
                                      See common/models.py/lstm for more details on using recurrent nets in policies.

    env: baselines.common.vec_env.VecEnv     environment. Needs to be vectorized for parallel environment simulation.
                                      The environments produced by gym.make can be wrapped using baselines.common.vec_env.DummyVecEnv class.


    nsteps: int                       number of steps of the vectorized environment per update (i.e. batch size is nsteps * nenv where
                                      nenv is number of environment copies simulated in parallel)

    total_timesteps: int              number of timesteps (i.e. number of actions taken in the environment)

    ent_coef: float                   policy entropy coefficient in the optimization objective

    lr: float or function             learning rate, constant or a schedule function [0,1] -> R+ where 1 is beginning of the
                                      training and 0 is the end of the training.

    vf_coef: float                    value function loss coefficient in the optimization objective

    max_grad_norm: float or None      gradient norm clipping coefficient

    gamma: float                      discounting factor

    lam: float                        advantage estimation discounting factor (lambda in the paper)

    log_interval: int                 number of updates between logging events

    nminibatches: int                 number of training minibatches per update. For recurrent policies,
                                      should be less than or equal to the number of environments run in parallel.

    noptepochs: int                   number of training epochs per update

    cliprange: float or function      clipping range, constant or schedule function [0,1] -> R+ where 1 is beginning of the training
                                      and 0 is the end of the training

    save_interval: int                number of updates between saving events

    load_path: str                    path to load the model from

    **network_kwargs:                 keyword arguments to the policy / network builder. See baselines.common/policies.py/build_policy and arguments to a particular type of network
                                      For instance, 'mlp' network architecture has arguments num_hidden and num_layers.



    '''

    set_global_seeds(seed)

    if isinstance(lr, float): lr = constfn(lr)
    else: assert callable(lr)
    if isinstance(cliprange, float): cliprange = constfn(cliprange)
    else: assert callable(cliprange)
    total_timesteps = int(total_timesteps)

    policy = build_policy(env, network, **network_kwargs)

    # Get the number of environments
    nenvs = env.num_envs

    # Get state_space and action_space
    ob_space = env.observation_space
    ac_space = env.action_space

    # Calculate the batch_size
    nbatch = nenvs * nsteps
    nbatch_train = nbatch // nminibatches

    # Instantiate the model object (that creates act_model and train_model)
    if model_fn is None:
        from baselines.ppo2.model import Model
        model_fn = Model

    model = model_fn(policy=policy, ob_space=ob_space, ac_space=ac_space, nbatch_act=nenvs, nbatch_train=nbatch_train,
                    nsteps=nsteps, ent_coef=ent_coef, vf_coef=vf_coef,
                    max_grad_norm=max_grad_norm)

    if load_path is not None:
        model.load(load_path)
    # Instantiate the runner object
    runner = Runner(env=env, model=model, nsteps=nsteps, gamma=gamma, lam=lam)
    if eval_env is not None:
        eval_runner = Runner(env = eval_env, model = model, nsteps = nsteps, gamma = gamma, lam= lam)

    epinfobuf = deque(maxlen=100)
    if eval_env is not None:
        eval_epinfobuf = deque(maxlen=100)

    # Start total timer
    tfirststart = time.perf_counter()

    nupdates = total_timesteps//nbatch
    for update in range(1, nupdates+1):
        assert nbatch % nminibatches == 0
        # Start timer
        tstart = time.perf_counter()
        frac = 1.0 - (update - 1.0) / nupdates
        # Calculate the learning rate
        lrnow = lr(frac)
        # Calculate the cliprange
        cliprangenow = cliprange(frac)
        # Get minibatch
        obs, returns, masks, actions, values, neglogpacs, states, epinfos = runner.run() #pylint: disable=E0632
        if eval_env is not None:
            eval_obs, eval_returns, eval_masks, eval_actions, eval_values, eval_neglogpacs, eval_states, eval_epinfos = eval_runner.run() #pylint: disable=E0632

        epinfobuf.extend(epinfos)
        if eval_env is not None:
            eval_epinfobuf.extend(eval_epinfos)

        # For each minibatch, calculate the loss and append it.
        mblossvals = []
        if states is None: # nonrecurrent version
            # Index of each element of batch_size
            # Create the indices array
            inds = np.arange(nbatch)
            for _ in range(noptepochs):
                # Randomize the indexes
                np.random.shuffle(inds)
                # 0 to batch_size with batch_train_size step
                for start in range(0, nbatch, nbatch_train):
                    end = start + nbatch_train
                    mbinds = inds[start:end]
                    slices = (arr[mbinds] for arr in (obs, returns, masks, actions, values, neglogpacs))
                    mblossvals.append(model.train(lrnow, cliprangenow, *slices))
        else: # recurrent version
            assert nenvs % nminibatches == 0
            envsperbatch = nenvs // nminibatches
            envinds = np.arange(nenvs)
            flatinds = np.arange(nenvs * nsteps).reshape(nenvs, nsteps)
            for _ in range(noptepochs):
                np.random.shuffle(envinds)
                for start in range(0, nenvs, envsperbatch):
                    end = start + envsperbatch
                    mbenvinds = envinds[start:end]
                    mbflatinds = flatinds[mbenvinds].ravel()
                    slices = (arr[mbflatinds] for arr in (obs, returns, masks, actions, values, neglogpacs))
                    mbstates = states[mbenvinds]
                    mblossvals.append(model.train(lrnow, cliprangenow, *slices, mbstates))

        # Feedforward --> get losses --> update
        lossvals = np.mean(mblossvals, axis=0)
        # End timer
        tnow = time.perf_counter()
        # Calculate the fps (frames per second)
        fps = int(nbatch / (tnow - tstart))
        if update % log_interval == 0 or update == 1:
            # Check whether the value function is a good predictor of the returns (ev close to 1)
            # or worse than predicting nothing (ev <= 0)
            ev = explained_variance(values, returns)
            logger.logkv("serial_timesteps", update*nsteps)
            logger.logkv("nupdates", update)
            logger.logkv("total_timesteps", update*nbatch)
            logger.logkv("fps", fps)
            logger.logkv("explained_variance", float(ev))
            logger.logkv('eprewmean', safemean([epinfo['r'] for epinfo in epinfobuf]))
            logger.logkv('eplenmean', safemean([epinfo['l'] for epinfo in epinfobuf]))
            if eval_env is not None:
                logger.logkv('eval_eprewmean', safemean([epinfo['r'] for epinfo in eval_epinfobuf]) )
                logger.logkv('eval_eplenmean', safemean([epinfo['l'] for epinfo in eval_epinfobuf]) )
            logger.logkv('time_elapsed', tnow - tfirststart)
            for (lossval, lossname) in zip(lossvals, model.loss_names):
                logger.logkv(lossname, lossval)
            if MPI is None or MPI.COMM_WORLD.Get_rank() == 0:
                logger.dumpkvs()
        if save_interval and (update % save_interval == 0 or update == 1) and logger.get_dir() and (MPI is None or MPI.COMM_WORLD.Get_rank() == 0):
            checkdir = osp.join(logger.get_dir(), 'checkpoints')
            os.makedirs(checkdir, exist_ok=True)
            savepath = osp.join(checkdir, '%.5i'%update)
            print('Saving to', savepath)
            model.save(savepath)
    return model
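A minimal usage sketch for the learn() defined above, assuming the module's own imports (Runner, build_policy, etc.) are in scope, gym is installed, and the environment is wrapped with DummyVecEnv as the docstring suggests; the environment id, timestep budget, and learning-rate schedule are illustrative only.

import gym
from baselines.common.vec_env.dummy_vec_env import DummyVecEnv

# One environment copy; nbatch = 1 * 128 and nbatch_train = 128 // 4 = 32.
env = DummyVecEnv([lambda: gym.make('CartPole-v1')])
model = learn(network='mlp',
              env=env,
              total_timesteps=100000,
              nsteps=128,
              nminibatches=4,
              lr=lambda f: 3e-4 * f,   # f decays from 1 (start of training) to 0 (end)
              cliprange=0.2)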
Example #3
def learn(*, policy, env, nsteps, total_timesteps, ent_coef, lr,
            vf_coef=0.5,  max_grad_norm=0.5, gamma=0.99, lam=0.95,
            log_interval=10, nminibatches=4, noptepochs=4, cliprange=0.2,
            save_interval=0, load_path=None, num_casks=0):

    if isinstance(lr, float): lr = constfn(lr)
    else: assert callable(lr)
    if isinstance(cliprange, float): cliprange = constfn(cliprange)
    else: assert callable(cliprange)
    total_timesteps = int(total_timesteps)

    nenvs = env.num_envs - num_casks
    ob_space = env.observation_space
    ac_space = env.action_space
    nbatch = nenvs * nsteps
    nbatch_train = nbatch // nminibatches

    make_model = lambda : Model(policy=policy, ob_space=ob_space, ac_space=ac_space, nbatch_act=env.num_envs, nbatch_train=nbatch_train,
                    nsteps=nsteps, ent_coef=ent_coef, vf_coef=vf_coef,
                    max_grad_norm=max_grad_norm)
    if save_interval and logger.get_dir():
        import cloudpickle
        with open(osp.join(logger.get_dir(), 'make_model.pkl'), 'wb') as fh:
            fh.write(cloudpickle.dumps(make_model))
    model = make_model()
    if load_path is not None:
        model.load(load_path)
        # load running mean std
        checkdir = load_path[0:-5]
        checkpoint = int(load_path.split('/')[-1])
        if osp.exists(osp.join(checkdir, '%.5i_ob_rms.pkl' % checkpoint)):
            with open(osp.join(checkdir, '%.5i_ob_rms.pkl' % checkpoint), 'rb') as ob_rms_fp:
                env.ob_rms = pickle.load(ob_rms_fp)
        # if osp.exists(osp.join(checkdir, '%.5i_ret_rms.pkl' % checkpoint)):
        #     with open(osp.join(checkdir, '%.5i_ret_rms.pkl' % checkpoint), 'rb') as ret_rms_fp:
        #         env.ret_rms = pickle.load(ret_rms_fp)
    # tensorboard
    writer = tf.summary.FileWriter(logger.get_dir(), tf.get_default_session().graph)
    runner = Runner(env=env, model=model, nsteps=nsteps, gamma=gamma, lam=lam, writer=writer, num_casks=num_casks)

    epinfobuf = deque(maxlen=100)
    tfirststart = time.time()

    nupdates = total_timesteps//nbatch
    for update in range(1, nupdates+1):
        assert nbatch % nminibatches == 0
        nbatch_train = nbatch // nminibatches
        tstart = time.time()
        frac = 1.0 - (update - 1.0) / nupdates
        lrnow = lr(frac)
        cliprangenow = cliprange(frac)
        obs, returns, masks, actions, values, neglogpacs, states, epinfos = runner.run() #pylint: disable=E0632
        epinfobuf.extend(epinfos)
        mblossvals = []
        if states is None: # nonrecurrent version
            inds = np.arange(nbatch)
            for _ in range(noptepochs):
                np.random.shuffle(inds)
                for start in range(0, nbatch, nbatch_train):
                    end = start + nbatch_train
                    mbinds = inds[start:end]
                    slices = (arr[mbinds] for arr in (obs, returns, masks, actions, values, neglogpacs))
                    mblossvals.append(model.train(lrnow, cliprangenow, *slices))
        else: # recurrent version
            assert nenvs % nminibatches == 0
            envsperbatch = nenvs // nminibatches
            envinds = np.arange(nenvs)
            flatinds = np.arange(nenvs * nsteps).reshape(nenvs, nsteps)
            envsperbatch = nbatch_train // nsteps
            for _ in range(noptepochs):
                np.random.shuffle(envinds)
                for start in range(0, nenvs, envsperbatch):
                    end = start + envsperbatch
                    mbenvinds = envinds[start:end]
                    mbflatinds = flatinds[mbenvinds].ravel()
                    slices = (arr[mbflatinds] for arr in (obs, returns, masks, actions, values, neglogpacs))
                    mbstates = states[mbenvinds]
                    mblossvals.append(model.train(lrnow, cliprangenow, *slices, mbstates))

        lossvals = np.mean(mblossvals, axis=0)
        tnow = time.time()
        fps = int(nbatch / (tnow - tstart))
        if update % log_interval == 0 or update == 1:
            ev = explained_variance(values, returns)
            logger.logkv("serial_timesteps", update*nsteps)
            logger.logkv("nupdates", update)
            logger.logkv("total_timesteps", update*nbatch)
            logger.logkv("fps", fps)
            logger.logkv("explained_variance", float(ev))
            logger.logkv('eprewmean', safemean([epinfo['r'] for epinfo in epinfobuf]))
            logger.logkv('eplenmean', safemean([epinfo['l'] for epinfo in epinfobuf]))
            logger.logkv('epsrewmean', safemean([epinfo['sr'] for epinfo in epinfobuf]))
            logger.logkv('time_elapsed', tnow - tfirststart)
            for (lossval, lossname) in zip(lossvals, model.loss_names):
                logger.logkv(lossname, lossval)
            logger.dumpkvs()
            # tensorboard
            summary = tf.Summary()
            summary.value.add(tag='iteration/reward_mean', simple_value=safemean([epinfo['r'] for epinfo in epinfobuf]))
            summary.value.add(tag='iteration/length_mean', simple_value=safemean([epinfo['l'] for epinfo in epinfobuf]))
            summary.value.add(tag='iteration/shaped_reward_mean', simple_value=safemean([epinfo['sr'] for epinfo in epinfobuf]))
            summary.value.add(tag='iteration/fps', simple_value=fps)
            writer.add_summary(summary, update)
        if save_interval and (update % save_interval == 0 or update == 1) and logger.get_dir():
            checkdir = osp.join(logger.get_dir(), 'checkpoints')
            os.makedirs(checkdir, exist_ok=True)
            savepath = osp.join(checkdir, '%.5i'%update)
            print('Saving to', savepath)
            model.save(savepath)
            # save running mean std
            with open(osp.join(checkdir, '%.5i_ob_rms.pkl' % update), 'wb') as ob_rms_fp:
                pickle.dump(env.ob_rms, ob_rms_fp)
            with open(osp.join(checkdir, '%.5i_ret_rms.pkl' % update), 'wb') as ret_rms_fp:
                pickle.dump(env.ret_rms, ret_rms_fp)
    env.close()
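The checkpoints above are saved as <logdir>/checkpoints/%.5i plus a matching %.5i_ob_rms.pkl file, and loading recovers the directory and step with load_path[0:-5] and load_path.split('/')[-1]. A small sketch of that naming convention using os.path (the path is made up for illustration):

import os.path as osp

load_path = '/tmp/logdir/checkpoints/00042'        # hypothetical checkpoint path
checkdir = osp.dirname(load_path)                  # '/tmp/logdir/checkpoints'
checkpoint = int(osp.basename(load_path))          # 42
ob_rms_path = osp.join(checkdir, '%.5i_ob_rms.pkl' % checkpoint)
# -> '/tmp/logdir/checkpoints/00042_ob_rms.pkl'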
Example #4
def learn(*,
          policy,
          env,
          nsteps,
          total_timesteps,
          ent_coef,
          lr,
          vf_coef=0.5,
          max_grad_norm=0.5,
          gamma=0.99,
          lam=0.95,
          log_interval=100,
          nminibatches=4,
          noptepochs=4,
          cliprange=0.2,
          save_interval=10000):

    if isinstance(lr, float): lr = constfn(lr)
    else: assert callable(lr)
    if isinstance(cliprange, float): cliprange = constfn(cliprange)
    else: assert callable(cliprange)
    total_timesteps = int(total_timesteps)

    nenvs = env.num_envs
    ob_space = env.observation_space
    ac_space = env.action_space
    nbatch = nenvs * nsteps
    nbatch_train = nbatch // nminibatches

    make_model = lambda: Model(policy=policy,
                               ob_space=ob_space,
                               ac_space=ac_space,
                               nbatch_act=nenvs,
                               nbatch_train=nbatch_train,
                               nsteps=nsteps,
                               ent_coef=ent_coef,
                               vf_coef=vf_coef,
                               max_grad_norm=max_grad_norm)
    if save_interval and logger.get_dir():
        import cloudpickle
        with open(osp.join(logger.get_dir(), 'make_model.pkl'), 'wb') as fh:
            fh.write(cloudpickle.dumps(make_model))
    model = make_model()
    runner = Runner(env=env, model=model, nsteps=nsteps, gamma=gamma, lam=lam)

    epinfobuf = deque(maxlen=100)
    tfirststart = time.time()

    nupdates = total_timesteps // nbatch
    for update in range(1, nupdates + 1):
        assert nbatch % nminibatches == 0
        nbatch_train = nbatch // nminibatches
        tstart = time.time()
        frac = 1.0 - (update - 1.0) / nupdates
        lrnow = lr(frac)
        cliprangenow = cliprange(frac)
        obs, returns, returns_square, masks, actions, values, moments, neglogpacs, states, epinfos = runner.run(
        )  #pylint: disable=E0632
        epinfobuf.extend(epinfos)
        mblossvals = []
        if states is None:  # nonrecurrent version
            inds = np.arange(nbatch)
            for _ in range(noptepochs):
                np.random.shuffle(inds)
                for start in range(0, nbatch, nbatch_train):
                    end = start + nbatch_train
                    mbinds = inds[start:end]
                    slices = (arr[mbinds]
                              for arr in (obs, returns, returns_square, masks,
                                          actions, values, moments,
                                          neglogpacs))
                    mblossvals.append(model.train(lrnow, cliprangenow,
                                                  *slices))
        else:  # recurrent version
            assert nenvs % nminibatches == 0
            envsperbatch = nenvs // nminibatches
            envinds = np.arange(nenvs)
            flatinds = np.arange(nenvs * nsteps).reshape(nenvs, nsteps)
            envsperbatch = nbatch_train // nsteps
            for _ in range(noptepochs):
                np.random.shuffle(envinds)
                for start in range(0, nenvs, envsperbatch):
                    end = start + envsperbatch
                    mbenvinds = envinds[start:end]
                    mbflatinds = flatinds[mbenvinds].ravel()
                    slices = (arr[mbflatinds]
                              for arr in (obs, returns, masks, actions, values,
                                          neglogpacs))
                    mbstates = states[mbenvinds]
                    mblossvals.append(
                        model.train(lrnow, cliprangenow, *slices, mbstates))

        lossvals = np.mean(mblossvals, axis=0)
        tnow = time.time()
        fps = int(nbatch / (tnow - tstart))
        if update % log_interval == 0 or update == 1:
            ev = explained_variance(values, returns)
            em = explained_variance(moments, returns_square)

            logger.logkv("serial_timesteps", update * nsteps)
            logger.logkv("nupdates", update)
            logger.logkv("total_timesteps", update * nbatch)
            logger.logkv("fps", fps)
            logger.logkv("explained_variance_rewards", float(ev))
            logger.logkv("explained_variance_moments", float(em))
            logger.logkv('episode_reward_mean',
                         safemean([epinfo['r'] for epinfo in epinfobuf]))
            logger.logkv('eplenmean',
                         safemean([epinfo['l'] for epinfo in epinfobuf]))
            logger.logkv('time_elapsed', tnow - tfirststart)
            for (lossval, lossname) in zip(lossvals, model.loss_names):
                logger.logkv(lossname, lossval)
            logger.dumpkvs()

        if save_interval and (update % save_interval == 0
                              or update == 1) and logger.get_dir():
            checkdir = osp.join(logger.get_dir(), 'checkpoints')
            os.makedirs(checkdir, exist_ok=True)
            savepath = osp.join(checkdir, '%.5i' % update)
            print('Saving to', savepath)
            model.save(savepath)
    if logger.get_dir():
        checkdir = osp.join(logger.get_dir(), 'checkpoints')
        os.makedirs(checkdir, exist_ok=True)
        savepath = osp.join(checkdir, '%.5i' % update)
        print('Saving to', savepath)
        model.save(savepath)
    env.close()
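The logging above uses explained_variance to judge how well the value head (and, in this variant, the second-moment head) predicts the empirical returns. A minimal numpy sketch of that statistic, assuming baselines' usual definition 1 - Var(y - ypred) / Var(y):

import numpy as np

def explained_variance_sketch(ypred, y):
    # 1 -> perfect prediction, 0 -> no better than predicting the mean,
    # negative -> worse than predicting nothing.
    vary = np.var(y)
    return np.nan if vary == 0 else 1 - np.var(y - ypred) / vary

returns = np.array([1.0, 2.0, 3.0, 4.0])
values = np.array([1.1, 1.9, 3.2, 3.8])
print(explained_variance_sketch(values, returns))   # ~0.98, a good value fit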
Example #5
def learn(*,
          network,
          env,
          total_timesteps,
          dtarg=0.01,
          adaptive_kl=0,
          trunc_rho=1.0,
          clipcut=0.2,
          useadv=0,
          vtrace=0,
          rgae=0,
          eval_env=None,
          seed=None,
          ERlen=1,
          nsteps=2048,
          ent_coef=0.0,
          lr=3e-4,
          vf_coef=0.5,
          max_grad_norm=None,
          gamma=0.99,
          lam=0.95,
          log_interval=10,
          nminibatches=4,
          noptepochs=4,
          cliprange=0.2,
          save_interval=0,
          load_path=None,
          **network_kwargs):
    '''
    Learn policy using PPO algorithm (https://arxiv.org/abs/1707.06347)

    Parameters:
    ----------

    network:                          policy network architecture. Either a string (mlp, lstm, lnlstm, cnn_lstm, cnn, cnn_small, conv_only - see baselines.common/models.py for the full list)
                                      naming a standard network architecture, or a function that takes a tensorflow tensor as input and returns a
                                      tuple (output_tensor, extra_feed), where output_tensor is the output of the last network layer and extra_feed is None for feed-forward
                                      neural nets or a dictionary describing how to feed state into the network for recurrent neural nets.
                                      See common/models.py/lstm for more details on using recurrent nets in policies.

    env: baselines.common.vec_env.VecEnv     environment. Needs to be vectorized for parallel environment simulation.
                                      The environments produced by gym.make can be wrapped using baselines.common.vec_env.DummyVecEnv class.


    nsteps: int                       number of steps of the vectorized environment per update (i.e. batch size is nsteps * nenv where
                                      nenv is number of environment copies simulated in parallel)

    total_timesteps: int              number of timesteps (i.e. number of actions taken in the environment)

    ent_coef: float                   policy entropy coefficient in the optimization objective

    lr: float or function             learning rate, constant or a schedule function [0,1] -> R+ where 1 is beginning of the
                                      training and 0 is the end of the training.

    vf_coef: float                    value function loss coefficient in the optimization objective

    max_grad_norm: float or None      gradient norm clipping coefficient

    gamma: float                      discounting factor

    lam: float                        advantage estimation discounting factor (lambda in the paper)

    log_interval: int                 number of updates between logging events

    nminibatches: int                 number of training minibatches per update. For recurrent policies,
                                      should be less than or equal to the number of environments run in parallel.

    noptepochs: int                   number of training epochs per update

    cliprange: float or function      clipping range, constant or schedule function [0,1] -> R+ where 1 is beginning of the training
                                      and 0 is the end of the training

    save_interval: int                number of updates between saving events

    load_path: str                    path to load the model from

    **network_kwargs:                 keyword arguments to the policy / network builder. See baselines.common/policies.py/build_policy and arguments to a particular type of network
                                      For instance, 'mlp' network architecture has arguments num_hidden and num_layers.



    '''

    set_global_seeds(seed)

    if isinstance(lr, float): lr = constfn(lr)
    else: assert callable(lr)
    if isinstance(cliprange, float): cliprange = constfn(cliprange)
    else: assert callable(cliprange)
    total_timesteps = int(total_timesteps)

    policy = build_policy(env, network, **network_kwargs)

    # Get the number of environments
    nenvs = env.num_envs

    # Get state_space and action_space
    ob_space = env.observation_space
    ac_space = env.action_space
    acdim = ac_space.shape[0]

    # Calculate the batch_size
    nbatch = nenvs * nsteps
    nbatch_train = nbatch // nminibatches

    # Instantiate the model object (that creates act_model and train_model)
    make_model = lambda: Model(policy=policy,
                               ob_space=ob_space,
                               ac_space=ac_space,
                               nbatch_act=nenvs,
                               nbatch_train=nbatch_train,
                               nsteps=nsteps,
                               ent_coef=ent_coef,
                               vf_coef=vf_coef,
                               max_grad_norm=max_grad_norm,
                               adaptive_kl=adaptive_kl)
    model = make_model()
    if load_path is not None:
        model.load(load_path)
    # Instantiate the runner object
    runner = Runner(env=env, model=model, nsteps=nsteps, gamma=gamma, lam=lam)
    if eval_env is not None:
        eval_runner = EvalRunner(env=eval_env,
                                 model=model,
                                 nsteps=10 * nsteps,
                                 gamma=gamma,
                                 lam=lam)
        eval_runner.obfilt = runner.obfilt
        eval_runner.rewfilt = runner.rewfilt

    epinfobuf = deque(maxlen=10)
    if eval_env is not None:
        eval_epinfobuf = deque(maxlen=10)

    # Start total timer
    tfirststart = time.time()

    nupdates = total_timesteps // nbatch

    def add_vtarg_and_adv(seg, gamma, value, lam):
        """
        Compute target value using TD(lambda) estimator, and advantage with GAE(lambda)
        """
        done = np.append(
            seg["done"], 0
        )  # the appended 0 is only used for the last vtarg; it is already zeroed if the last step was terminal

        T = len(seg["rew"])
        gaelam = np.empty(T, 'float32')
        rew = runner.rewfilt(seg["rew"])
        lastgaelam = 0
        for t in reversed(range(T)):
            nonterminal = 1 - done[t + 1]
            delta = rew[t] + gamma * value[t + 1] * nonterminal - value[t]
            gaelam[
                t] = lastgaelam = delta + gamma * lam * nonterminal * lastgaelam
        ret = gaelam + value[:-1]
        return gaelam, ret

    def add_vtarg_and_adv_vtrace(seg,
                                 gamma,
                                 value,
                                 rho,
                                 trunc_rho,
                                 acdim=None):
        """
        Compute target value using TD(lambda) estimator, and advantage with GAE(lambda)
        """
        done = np.append(
            seg["done"], 0
        )  # the appended 0 is only used for the last vtarg; it is already zeroed if the last step was terminal
        rho_ = np.append(rho, 1.0)
        if acdim is not None:
            rho_ = np.exp(np.log(rho_) / acdim)

        r = np.minimum(trunc_rho, rho_)
        c = lam * np.minimum(1.0, rho_)
        T = len(seg["rew"])
        gaelam = np.empty(T, 'float32')
        gaelam2 = np.empty(T, 'float32')
        rew = runner.rewfilt(seg["rew"])
        lastgaelam = 0
        for t in reversed(range(T)):
            nonterminal = 1 - done[t + 1]
            delta = (rew[t] + gamma * value[t + 1] * nonterminal - value[t])
            gaelam[t] = delta + gamma * lam * nonterminal * lastgaelam
            lastgaelam = r[t] * gaelam[t]
        ret = r[:-1] * gaelam + value[:-1]
        adv = rew + gamma * (1.0 - done[1:]) * np.hstack([ret[1:], value[T]
                                                          ]) - value[:-1]
        return adv, ret, gaelam

    def add_vtarg_and_adv_vtrace4(seg,
                                  gamma,
                                  value,
                                  rho,
                                  trunc_rho,
                                  acdim=None):
        """
        Compute target value using TD(lambda) estimator, and advantage with GAE(lambda)
        """
        done = np.append(
            seg["done"], 0
        )  # the appended 0 is only used for the last vtarg; it is already zeroed if the last step was terminal
        rho_ = np.append(rho, 1.0)
        if acdim is not None:
            rho_ = np.exp(np.log(rho_) / acdim)

        T = len(seg["rew"])
        gaelam = np.zeros(T, 'float32')
        rew = runner.rewfilt(seg["rew"])
        delta = (rew + gamma * value[1:] * (1.0 - done[1:]) - value[:-1])
        gamlam = np.zeros(T, 'float32')
        for i in range(T):
            gamlam[i] = (gamma * lam)**i
        idx = T
        c = np.ones(T)
        for t in reversed(range(T)):
            # print(delta2)
            for j in range(t, T):
                if done[j + 1]:
                    idx = j + 1
                break
            gaelam[t] = np.sum(gamlam[:idx - t] *
                               (np.minimum(1.0, c) * delta)[t:idx])
            c[t:] = rho_[t] * c[t:]

        ret = np.minimum(trunc_rho, rho_[:-1]) * gaelam + value[:-1]
        adv = rew + gamma * (1.0 - done[1:]) * np.hstack([ret[1:], value[T]
                                                          ]) - value[:-1]
        return adv, ret, gaelam

    seg = None
    cliprangenow = cliprange(1.0)
    klconst = 1.0
    for update in range(1, nupdates + 1):
        assert nbatch % nminibatches == 0
        # Start timer
        tstart = time.time()
        frac = 1.0 - (update - 1.0) / nupdates
        # Calculate the learning rate
        lrnow = np.maximum(1e-4, lr(frac))
        # Calculate the cliprange

        # Get minibatch
        if seg is None:
            prev_seg = seg
            seg = {}
        else:
            prev_seg = {}
            for i in seg:
                prev_seg[i] = np.copy(seg[i])
        seg["ob"], seg["rew"], seg["done"], seg["ac"], seg["neglogp"], seg[
            "mean"], seg[
                "logstd"], final_obs, final_done, epinfos = runner.run()  #pylint: disable=E0632
        # print(np.shape(seg["ob"]))
        if prev_seg is not None:
            for key in seg:
                if len(np.shape(seg[key])) == 1:
                    seg[key] = np.hstack([prev_seg[key], seg[key]])
                else:
                    seg[key] = np.vstack([prev_seg[key], seg[key]])
                if np.shape(seg[key])[0] > ERlen * nsteps:
                    seg[key] = seg[key][-ERlen * nsteps:]

        ob_stack = np.vstack([seg["ob"], final_obs])
        values = model.values(runner.obfilt(ob_stack))
        values[-1] = (1.0 - final_done) * values[-1]
        ob = runner.obfilt(seg["ob"])
        mean_now, logstd_now = model.meanlogstds(ob)
        # print(np.shape(seg["ac"])[1])
        neglogpnow = 0.5 * np.sum(np.square((seg["ac"] - mean_now) / np.exp(logstd_now)), axis=-1) \
                      + 0.5 * np.log(2.0 * np.pi) * np.shape(seg["ac"])[1] \
                      + np.sum(logstd_now, axis=-1)
        rho = np.exp(-neglogpnow + seg["neglogp"])
        # print(len(mean_now))
        # print(cliprangenow)
        # print(rho)
        if vtrace == 1:
            adv, ret, gae = add_vtarg_and_adv_vtrace(seg, gamma, values, rho,
                                                     trunc_rho)
            if useadv:
                gae = adv
        elif vtrace == 4:
            adv, ret, gae = add_vtarg_and_adv_vtrace4(seg, gamma, values, rho,
                                                      trunc_rho)
            if useadv:
                gae = adv
        else:
            gae, ret = add_vtarg_and_adv(seg, gamma, values, lam)
        r = np.minimum(1.0, rho)
        r_gae = gae * r
        print("======")
        print(gae)
        print(r_gae)
        print(gae.mean())
        print(r_gae.mean())
        print(gae.std())
        print(r_gae.std())
        print(r.mean())
        print("======")

        if eval_env is not None:
            eval_obs, eval_returns, eval_masks, eval_actions, _, _, eval_epinfos = eval_runner.run(
            )  #pylint: disable=E0632
        prior_row = np.zeros(len(seg["ob"]))
        temp_prior = []
        for i in range(int(len(prior_row) / nsteps)):
            temp_row = np.mean(
                np.abs(rho[i * nsteps:(i + 1) * nsteps] - 1.0) + 1.0)
            # local_rho[i + (ERlen-int(len(prior_row)/nsteps))].append(temp_row)
            if temp_row > 1 + clipcut:
                prior_row[i * nsteps:(i + 1) * nsteps] = 0
            else:
                prior_row[i * nsteps:(i + 1) * nsteps] = 1
            temp_prior.append(temp_row)
        print(temp_prior)

        # for i in range(len(prior_row)):
        #     if (np.abs(rho[i] - 1.0) + 1.0)>1.05:
        #         prior_row[i]=0
        #     else:
        #         prior_row[i]=1
        # for i in range(len(prior_row)):
        #     if rho[i]>1.1 :
        #         prior_row[i]=0
        #     else:
        #         prior_row[i]=1
        # prob = prior_row/np.sum(prior_row)

        print(np.sum(prior_row))

        epinfobuf.extend(epinfos)
        if eval_env is not None:
            eval_epinfobuf.extend(eval_epinfos)

        # For each minibatch, calculate the loss and append it.
        mblossvals = []
        # Index of each element of batch_size
        # Create the indices array

        inds1 = np.arange(len(seg["ob"]) - nsteps)
        inds2 = np.arange(nsteps) + len(seg["ob"]) - nsteps
        print(len(seg["ob"]))
        print(cliprangenow)
        nbatch_adapt1 = int(
            (np.sum(prior_row) - nsteps) / nsteps * nbatch_train)
        nbatch_adapt2 = int((nsteps) / nsteps * nbatch_train)
        print(rho)
        idx1 = []
        idx2 = []
        kl_rest = np.ones(len(seg["ob"])) * np.sum(prior_row) / nsteps
        kl_rest[:-nsteps] = 0
        # print(kl_rest)
        for _ in range(noptepochs):
            # Randomize the indexes
            # np.random.shuffle(inds)
            # 0 to batch_size with batch_train_size step

            # print(nbatch_adapt)
            losses_epoch = []
            for _ in range(int(nsteps / nbatch_train)):
                if nbatch_adapt1 > 0:
                    idx1 = np.random.choice(inds1,
                                            nbatch_adapt1,
                                            p=prior_row[:-2048] /
                                            np.sum(prior_row[:-2048]))
                idx2 = np.random.choice(inds2, nbatch_adapt2)
                # print(np.mean(np.abs(rho[mbinds] - 1.0) + 1.0))
                idx = np.hstack([idx1, idx2]).astype(int)

                slices = (arr[idx]
                          for arr in (ob, ret, gae, seg["done"], seg["ac"],
                                      values[:-1], seg["neglogp"], seg["mean"],
                                      seg["logstd"], kl_rest, rho, neglogpnow))
                loss_epoch = model.train(lrnow, cliprangenow, klconst, rgae,
                                         trunc_rho, *slices)
                mblossvals.append(loss_epoch)
                losses_epoch.append(loss_epoch)

            # # print(np.mean(losses_epoch, axis=0))
            # mean_n, logstd_n = model.meanlogstds(runner.obfilt(seg["ob"]))
            # # print(np.shape(seg["ac"])[1])
            # rho_after =  np.exp(- 0.5 * np.square((seg["ac"] - mean_n) / np.exp(logstd_n)) \
            #              - logstd_n + 0.5 * np.square((seg["ac"] - seg["mean"]) / np.exp(seg["logstd"]))\
            #              + seg["logstd"])
            # temp_ = []
            # for i in range(int(len(prior_row) / nsteps)):
            #     temp_row = np.mean(np.abs(rho_after[i * nsteps:(i + 1) * nsteps] - 1.0) + 1.0)
            #     # local_rho[i + (ERlen-int(len(prior_row)/nsteps))].append(temp_row)
            #     temp_.append(temp_row)
            # print(temp_)

        # Feedforward --> get losses --> update
        lossvals = np.mean(mblossvals, axis=0)
        if adaptive_kl:
            print("KL avg :", lossvals[3])
            if lossvals[3] > dtarg * 1.5:
                klconst *= 2
                print("kl const is increased")
            elif lossvals[3] < dtarg / 1.5:
                klconst /= 2
                print("kl const is reduced")
            klconst = np.clip(klconst, 2**(-10), 64)
        # End timer
        tnow = time.time()
        # Calculate the fps (frames per second)
        fps = int(nbatch / (tnow - tstart))
        if update % log_interval == 0 or update == 1:
            # Check whether the value function is a good predictor of the returns (ev close to 1)
            # or worse than predicting nothing (ev <= 0)
            ev = explained_variance(values[:-1], ret)
            logger.logkv("batch IS weight",
                         [int(1000 * s) / 1000. for s in np.array(temp_prior)])
            logger.logkv("kl const", klconst)
            logger.logkv("clipping factor", cliprangenow)
            logger.logkv("learning rate", lrnow)
            logger.logkv("serial_timesteps", update * nsteps)
            logger.logkv("nupdates", update)
            logger.logkv("total_timesteps", update * nbatch)
            logger.logkv("fps", fps)
            logger.logkv("explained_variance", float(ev))
            logger.logkv('eprewmean',
                         safemean([epinfo['r'] for epinfo in epinfobuf]))
            logger.logkv('eplenmean',
                         safemean([epinfo['l'] for epinfo in epinfobuf]))
            if eval_env is not None:
                logger.logkv(
                    'eval_eprewmean',
                    safemean([epinfo['r'] for epinfo in eval_epinfos]))
                logger.logkv(
                    'eval_eplenmean',
                    safemean([epinfo['l'] for epinfo in eval_epinfos]))
            logger.logkv('time_elapsed', tnow - tfirststart)
            for (lossval, lossname) in zip(lossvals, model.loss_names):
                logger.logkv(lossname, lossval)
            if MPI is None or MPI.COMM_WORLD.Get_rank() == 0:
                logger.dumpkvs()
        if save_interval and (update % save_interval == 0
                              or update == 1) and logger.get_dir() and (
                                  MPI is None
                                  or MPI.COMM_WORLD.Get_rank() == 0):
            checkdir = osp.join(logger.get_dir(), 'checkpoints')
            os.makedirs(checkdir, exist_ok=True)
            savepath = osp.join(checkdir, '%.5i' % update)
            print('Saving to', savepath)
            model.save(savepath)
    return model
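The add_vtarg_and_adv helper above is the standard GAE(lambda) recursion. The toy numpy sketch below reproduces it (minus the reward filtering) on a hand-made three-step segment; rewards and values are made up for illustration.

import numpy as np

gamma, lam = 0.99, 0.95
rew = np.array([1.0, 0.0, 1.0])           # rewards for steps t = 0, 1, 2
done = np.array([0, 0, 0, 0])             # no terminal steps; last entry is the appended 0
value = np.array([0.5, 0.6, 0.4, 0.3])    # V(s_0..s_2) plus the bootstrap V(s_3)

T = len(rew)
gaelam = np.empty(T, 'float32')
lastgaelam = 0.0
for t in reversed(range(T)):
    nonterminal = 1 - done[t + 1]
    delta = rew[t] + gamma * value[t + 1] * nonterminal - value[t]
    gaelam[t] = lastgaelam = delta + gamma * lam * nonterminal * lastgaelam
ret = gaelam + value[:-1]                 # TD(lambda) targets for the value loss
print(gaelam, ret)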
Example #6
def learn(*,
          network,
          env,
          total_timesteps,
          eval_env=None,
          seed=None,
          nsteps=2048,
          ent_coef=0.0,
          lr=3e-4,
          vf_coef=0.5,
          max_grad_norm=0.5,
          gamma=0.99,
          lam=0.95,
          log_interval=1,
          nminibatches=4,
          noptepochs=4,
          cliprange=0.2,
          save_interval=0,
          load_path=None,
          model_fn=None,
          update_fn=None,
          init_fn=None,
          mpi_rank_weight=1,
          comm=None,
          **network_kwargs):
    '''
    Learn policy using PPO algorithm (https://arxiv.org/abs/1707.06347)

    Parameters:
    ----------

    network:                          policy network architecture. Either a string (mlp, lstm, lnlstm, cnn_lstm, cnn, cnn_small, conv_only - see baselines.common/models.py for the full list)
                                      naming a standard network architecture, or a function that takes a tensorflow tensor as input and returns a
                                      tuple (output_tensor, extra_feed), where output_tensor is the output of the last network layer and extra_feed is None for feed-forward
                                      neural nets or a dictionary describing how to feed state into the network for recurrent neural nets.
                                      See common/models.py/lstm for more details on using recurrent nets in policies.

    env: baselines.common.vec_env.VecEnv     environment. Needs to be vectorized for parallel environment simulation.
                                      The environments produced by gym.make can be wrapped using baselines.common.vec_env.DummyVecEnv class.


    nsteps: int                       number of steps of the vectorized environment per update (i.e. batch size is nsteps * nenv where
                                      nenv is number of environment copies simulated in parallel)

    total_timesteps: int              number of timesteps (i.e. number of actions taken in the environment)

    ent_coef: float                   policy entropy coefficient in the optimization objective

    lr: float or function             learning rate, constant or a schedule function [0,1] -> R+ where 1 is beginning of the
                                      training and 0 is the end of the training.

    vf_coef: float                    value function loss coefficient in the optimization objective

    max_grad_norm: float or None      gradient norm clipping coefficient

    gamma: float                      discounting factor

    lam: float                        advantage estimation discounting factor (lambda in the paper)

    log_interval: int                 number of updates between logging events

    nminibatches: int                 number of training minibatches per update. For recurrent policies,
                                      should be less than or equal to the number of environments run in parallel.

    noptepochs: int                   number of training epochs per update

    cliprange: float or function      clipping range, constant or schedule function [0,1] -> R+ where 1 is beginning of the training
                                      and 0 is the end of the training

    save_interval: int                number of updates between saving events

    load_path: str                    path to load the model from

    **network_kwargs:                 keyword arguments to the policy / network builder. See baselines.common/policies.py/build_policy and arguments to a particular type of network
                                      For instance, 'mlp' network architecture has arguments num_hidden and num_layers.



    '''
    print(f"load_path is {load_path}")
    set_global_seeds(seed)

    if isinstance(lr, float): lr = constfn(lr)
    else: assert callable(lr)
    if isinstance(cliprange, float): cliprange = constfn(cliprange)
    else: assert callable(cliprange)
    total_timesteps = int(total_timesteps)

    policy = build_policy(env, network, **network_kwargs)

    # Get the number of environments (per side)
    nenvs = env.num_envs // env.sides
    nminibatches = nenvs

    # Get state_space and action_space
    ob_space = env.observation_space
    ac_space = env.action_space

    # Calculate the batch_size
    nbatch = nenvs * nsteps
    nbatch_train = nbatch // nminibatches
    is_mpi_root = (MPI is None or MPI.COMM_WORLD.Get_rank() == 0)

    # Instantiate the model object (that creates act_model and train_model)
    if model_fn is None:
        from baselines.ppo2.model import Model
        model_fn = Model
    print({
        'ob_space': ob_space,
        'ac_space': ac_space,
        'nenvs': nenvs,
        'nbatch_train': nbatch_train,
        'nsteps': nsteps
    })

    model = model_fn(policy=policy,
                     ob_space=ob_space,
                     ac_space=ac_space,
                     nbatch_act=nenvs,
                     nbatch_train=nbatch_train,
                     nsteps=nsteps,
                     ent_coef=ent_coef,
                     vf_coef=vf_coef,
                     max_grad_norm=max_grad_norm,
                     comm=comm,
                     mpi_rank_weight=mpi_rank_weight,
                     side=0)
    if total_timesteps == 0:
        return model
    model_opponents = [
        model_fn(policy=policy,
                 ob_space=ob_space,
                 ac_space=ac_space,
                 nbatch_act=nenvs,
                 nbatch_train=nbatch_train,
                 nsteps=nsteps,
                 ent_coef=ent_coef,
                 vf_coef=vf_coef,
                 max_grad_norm=max_grad_norm,
                 comm=comm,
                 mpi_rank_weight=mpi_rank_weight,
                 side=i) for i in range(1, env.sides)
    ]
    # Opponent models for the remaining sides (self-play)
    if load_path is not None:
        model.load(load_path)
    # Instantiate the runner object
    runner = Runner(env=env,
                    model=model,
                    model_opponents=model_opponents,
                    nsteps=nsteps,
                    gamma=gamma,
                    lam=lam)
    if eval_env is not None:
        eval_runner = Runner(env=eval_env,
                             model=model,
                             model_opponents=model_opponents,
                             nsteps=nsteps,
                             gamma=gamma,
                             lam=lam)

    epinfobuf = deque(maxlen=100)
    opponents_epinfobuf = [deque(maxlen=100) for _ in range(env.sides)]
    if eval_env is not None:
        eval_epinfobuf = deque(maxlen=100)

    if init_fn is not None:
        init_fn()

    # Start total timer
    tfirststart = time.perf_counter()

    nupdates = total_timesteps // nbatch
    for update in tqdm(range(1, nupdates + 1)):
        assert nbatch % nminibatches == 0
        # Start timer
        tstart = time.perf_counter()
        frac = 1.0 - (update - 1.0) / nupdates
        # Calculate the learning rate
        lrnow = lr(frac)
        # Calculate the cliprange
        cliprangenow = cliprange(frac)

        if update % log_interval == 0 and is_mpi_root:
            logger.info('Stepping environment...')

        # Get minibatch
        obs, returns, masks, actions, values, neglogpacs, states, epinfos = runner.run(
            1 - (update - 1) / nupdates)  #pylint: disable=E0632
        if eval_env is not None:
            eval_obs, eval_returns, eval_masks, eval_actions, eval_values, eval_neglogpacs, eval_states, eval_epinfos = eval_runner.run(
                1 - (update - 1) / nupdates)  #pylint: disable=E0632

        if update % log_interval == 0 and is_mpi_root: logger.info('Done.')

        epinfobuf.extend(epinfos[0::env.sides])
        for i in range(env.sides - 1):
            opponents_epinfobuf[i].extend(epinfos[i + 1::env.sides])
        if eval_env is not None:
            eval_epinfobuf.extend(eval_epinfos)

        # For each minibatch, calculate the loss and append it.
        mblossvals = []
        # mb_opponent_lossvals = []
        if states is None:  # nonrecurrent version
            # Index of each element of batch_size
            # Create the indices array
            inds = np.arange(nbatch)
            for _ in range(noptepochs):
                # Randomize the indexes
                np.random.shuffle(inds)
                # 0 to batch_size with batch_train_size step
                for start in range(0, nbatch, nbatch_train):
                    end = start + nbatch_train
                    mbinds = inds[start:end]
                    slices = (arr[mbinds]
                              for arr in (obs, returns, masks, actions, values,
                                          neglogpacs))
                    mblossvals.append(model.train(lrnow, cliprangenow,
                                                  *slices))
                    # mb_opponent_lossval = []
                    # for i in range(env.sides-1):
                    #     mb_opponent_lossval.append(model_opponents[i].train(lrnow, cliprangenow, *slices))
                    # mb_opponent_lossvals.append(mb_opponent_lossval)
        else:  # recurrent version
            assert nenvs % nminibatches == 0
            envsperbatch = nenvs // nminibatches
            envinds = np.arange(nenvs)
            flatinds = np.arange(nenvs * nsteps).reshape(nenvs, nsteps)
            for _ in range(noptepochs):
                np.random.shuffle(envinds)
                for start in range(0, nenvs, envsperbatch):
                    end = start + envsperbatch
                    mbenvinds = envinds[start:end]
                    mbflatinds = flatinds[mbenvinds].ravel()
                    slices = (arr[mbflatinds]
                              for arr in (obs, returns, masks, actions, values,
                                          neglogpacs))
                    mbstates = states[mbenvinds]
                    mblossvals.append(
                        model.train(lrnow, cliprangenow, *slices, mbstates))
                    # mb_opponent_lossval = []
                    # for i in range(env.sides-1):
                    #     mb_opponent_lossval.append(model_opponents[i].train(lrnow, cliprangenow, *slices, mbstates))
                    # mb_opponent_lossvals.append(mb_opponent_lossval)

        # Feedforward --> get losses --> update
        lossvals = np.mean(mblossvals, axis=0)
        # End timer
        tnow = time.perf_counter()
        # Calculate the fps (frames per second)
        fps = int(nbatch / (tnow - tstart))

        if update_fn is not None:
            update_fn(update)

        if update % log_interval == 0 or update == 1:
            # Check whether the value function is a good predictor of the returns (ev close to 1)
            # or worse than predicting nothing (ev <= 0)
            wandb_log_dic = {}
            ev = explained_variance(values, returns)
            logger.logkv("misc/serial_timesteps", update * nsteps)
            logger.logkv("misc/nupdates", update)
            logger.logkv("misc/total_timesteps", update * nbatch)
            logger.logkv("fps", fps)
            logger.logkv("misc/explained_variance", float(ev))
            wandb_log_dic["misc/explained_variance"] = float(ev)
            if eval_env is not None:
                logger.logkv(
                    'eval_eprewmean',
                    safemean([epinfo['r'] for epinfo in eval_epinfobuf]))
                wandb_log_dic['eval_eprewmean'] = safemean(
                    [epinfo['r'] for epinfo in eval_epinfobuf])
                logger.logkv(
                    'eval_eplenmean',
                    safemean([epinfo['l'] for epinfo in eval_epinfobuf]))
                wandb_log_dic['eval_eplenmean'] = safemean(
                    [epinfo['l'] for epinfo in eval_epinfobuf])
            for key in epinfobuf[0]:
                logger.logkv('ep_' + key + '_mean',
                             safemean([epinfo[key] for epinfo in epinfobuf]))
            for i in range(env.sides - 1):
                for key in epinfobuf[0]:
                    logger.logkv(
                        f'opponent_{i}_ep_' + key + '_mean',
                        safemean([
                            epinfo[key] for epinfo in opponents_epinfobuf[i]
                        ]))

            logger.logkv('misc/time_elapsed', tnow - tfirststart)
            for (lossval, lossname) in zip(lossvals, model.loss_names):
                logger.logkv('loss/' + lossname, lossval)
                wandb_log_dic['loss/' + lossname] = lossval
            logger.dumpkvs()
        if save_interval and (update % save_interval == 0 or update
                              == 1) and logger.get_dir() and is_mpi_root:
            checkdir = osp.join(logger.get_dir(), 'checkpoints')
            os.makedirs(checkdir, exist_ok=True)
            savepath = osp.join(checkdir, '%.5i' % update)
            print('Saving to', savepath)
            runner.save(savepath)
            model = runner.model
            model_opponents = runner.model_opponents

    return model
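Example #6 interleaves episode infos from every side of the multi-agent env and de-interleaves them with strided slicing (epinfos[0::env.sides] for the learner, epinfos[i + 1::env.sides] for each opponent). A small sketch of that slicing on made-up data, assuming two sides:

sides = 2
# Interleaved episode infos as the runner would return them: learner, opponent, learner, ...
epinfos = [{'side': step % sides, 'r': float(step)} for step in range(6)]

learner_infos = epinfos[0::sides]                          # entries 0, 2, 4 -> learner episodes
opponent_infos = [epinfos[i + 1::sides] for i in range(sides - 1)]
print(learner_infos)
print(opponent_infos)                                      # one list per opponent side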
Example #7
def learn(*, network,
        env, eval_env, make_eval_env, env_id,
        total_timesteps, seed=None, nsteps=2048, ent_coef=0.0, lr=3e-4,
        vf_coef=0.5,  max_grad_norm=0.5, gamma=0.99, lam=0.95,
        log_interval=10, nminibatches=4, noptepochs=4, cliprange=0.2,
        sil_update=10, sil_value=0.01, sil_alpha=0.6, sil_beta=0.1, sil_loss=0.1,

        # MBL
        # For train mbl
        mbl_train_freq=5,
        # For eval
        num_eval_episodes=5,
        eval_freq=5,
        vis_eval=False,
#        eval_targs=('mbmf',),
        eval_targs=('mf',),
        quant=2,

        # For mbl.step
        #num_samples=(1500,),
        num_samples=(1,),
        horizon=(2,),
        #horizon=(2,1),
        #num_elites=(10,),
        num_elites=(1,),
        mbl_lamb=(1.0,),
        mbl_gamma=0.99,
        #mbl_sh=1, # Number of step for stochastic sampling
        mbl_sh=10000,
        #vf_lookahead=-1,
        #use_max_vf=False,
        reset_per_step=(0,),

        # For get_model
        num_fc=2,
        num_fwd_hidden=500,
        use_layer_norm=False,        

        # For MBL
        num_warm_start=int(1e4),            
        init_epochs=10, 
        update_epochs=5, 
        batch_size=512, 
        update_with_validation=False, 
        use_mean_elites=1,
        use_ent_adjust=0,
        adj_std_scale=0.5,

        # For data loading
        validation_set_path=None, 

        # For data collect
        collect_val_data=False,

        # For traj collect
        traj_collect='mf',      
        
        # For profile
        measure_time=True,
        eval_val_err=False,
        measure_rew=True,
        save_interval=0, load_path=None, model_fn=None, update_fn=None, init_fn=None, mpi_rank_weight=1, comm=None, **network_kwargs):
    '''
    Learn policy using PPO algorithm (https://arxiv.org/abs/1707.06347)

    Parameters:
    ----------

    network:                          policy network architecture. Either a string (mlp, lstm, lnlstm, cnn_lstm, cnn, cnn_small, conv_only - see baselines.common/models.py for the full list)
                                      naming a standard network architecture, or a function that takes a tensorflow tensor as input and returns a
                                      tuple (output_tensor, extra_feed), where output_tensor is the output of the last network layer and extra_feed is None for feed-forward
                                      neural nets or a dictionary describing how to feed state into the network for recurrent neural nets.
                                      See common/models.py/lstm for more details on using recurrent nets in policies.

    env: baselines.common.vec_env.VecEnv     environment. Needs to be vectorized for parallel environment simulation.
                                      The environments produced by gym.make can be wrapped using baselines.common.vec_env.DummyVecEnv class.


    nsteps: int                       number of steps of the vectorized environment per update (i.e. batch size is nsteps * nenv where
                                      nenv is number of environment copies simulated in parallel)

    total_timesteps: int              number of timesteps (i.e. number of actions taken in the environment)

    ent_coef: float                   policy entropy coefficient in the optimization objective

    lr: float or function             learning rate, constant or a schedule function [0,1] -> R+ where 1 is beginning of the
                                      training and 0 is the end of the training.

    vf_coef: float                    value function loss coefficient in the optimization objective

    max_grad_norm: float or None      gradient norm clipping coefficient

    gamma: float                      discounting factor

    lam: float                        advantage estimation discounting factor (lambda in the paper)

    log_interval: int                 number of timesteps between logging events

    nminibatches: int                 number of training minibatches per update. For recurrent policies,
                                      should be smaller than or equal to the number of environments run in parallel.

    noptepochs: int                   number of training epochs per update

    cliprange: float or function      clipping range, constant or schedule function [0,1] -> R+ where 1 is beginning of the training
                                      and 0 is the end of the training

    save_interval: int                number of timesteps between saving events

    load_path: str                    path to load the model from

    **network_kwargs:                 keyword arguments to the policy / network builder. See baselines.common/policies.py/build_policy and arguments to a particular type of network
                                      For instance, 'mlp' network architecture has arguments num_hidden and num_layers.



    '''
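    # Illustrative note (not from the original source): lr and cliprange may each be a constant
    # or a schedule f(frac) -> value, where frac is computed in the update loop below as
    # 1 - (update - 1) / nupdates and so anneals from 1.0 toward 0.0; e.g. lr=lambda frac: 3e-4 * frac
    # gives a linear decay.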
    if not isinstance(num_samples, tuple): num_samples = (num_samples,)
    if not isinstance(horizon, tuple): horizon = (horizon,)
    if not isinstance(num_elites, tuple): num_elites = (num_elites,)
    if not isinstance(mbl_lamb, tuple): mbl_lamb = (mbl_lamb,)
    if not isinstance(reset_per_step, tuple): reset_per_step = (reset_per_step,)
    if validation_set_path is None: 
        if collect_val_data: validation_set_path = os.path.join(logger.get_dir(), 'val.pkl')
        else: validation_set_path = os.path.join('dataset', '{}-val.pkl'.format(env_id))
    if eval_val_err:
        eval_val_err_path = os.path.join('dataset', '{}-combine-val.pkl'.format(env_id))
    logger.log(locals())
    logger.log('MBL_SH', mbl_sh)
    logger.log('Traj_collect', traj_collect)
    
    if MPI is not None:
        nworkers = MPI.COMM_WORLD.Get_size()
        rank = MPI.COMM_WORLD.Get_rank()
    else:
        nworkers = 1
        rank = 0  
    cpus_per_worker = 1
    U.get_session(config=tf.ConfigProto(
            allow_soft_placement=True,
            inter_op_parallelism_threads=cpus_per_worker,
            intra_op_parallelism_threads=cpus_per_worker
    ))

    set_global_seeds(seed)

    if isinstance(lr, float): lr = constfn(lr)
    else: assert callable(lr)
    if isinstance(cliprange, float): cliprange = constfn(cliprange)
    else: assert callable(cliprange)
    total_timesteps = int(total_timesteps)
    
    policy = build_policy(env, network, **network_kwargs)
    np.set_printoptions(precision=3)
    # Get the nb of env
    nenvs = env.num_envs
    # Get state_space and action_space
    ob_space = env.observation_space
    ac_space = env.action_space

    # Calculate the batch_size
    nbatch = nenvs * nsteps
    nbatch_train = nbatch // nminibatches
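    # Illustrative arithmetic (assumed values): with nenvs = 8 and nsteps = 2048, nbatch = 16384,
    # and with nminibatches = 4 each gradient step uses nbatch_train = 4096 transitions.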
    is_mpi_root = (MPI is None or MPI.COMM_WORLD.Get_rank() == 0)

    # Instantiate the model object (that creates act_model and train_model)
    if model_fn is None:
        model_fn = Model

    make_model = lambda: Model(policy=policy, ob_space=ob_space, ac_space=ac_space, nbatch_act=nenvs, nbatch_train=nbatch_train,
                          nsteps=nsteps, ent_coef=ent_coef, vf_coef=vf_coef,
                          max_grad_norm=max_grad_norm,
                          sil_update=sil_update,
                          fn_reward=None, fn_obs=None,
                          sil_value=sil_value, 
                          sil_alpha=sil_alpha, 
                          sil_beta=sil_beta,
                          sil_loss=sil_loss,
                          comm=comm, mpi_rank_weight=mpi_rank_weight,
                          ppo=True,prev_pi=None)
    model = make_model()
    pi = model.sil_model
    
    if load_path is not None:
        model.load(load_path)

    # MBL
    # ---------------------------------------
    #viz = Visdom(env=env_id) 
    win = None
    eval_targs = list(eval_targs)
    logger.log(eval_targs)

    make_model_f = get_make_mlp_model(num_fc=num_fc, num_fwd_hidden=num_fwd_hidden, layer_norm=use_layer_norm)
    mbl = MBL(env=eval_env, env_id=env_id, make_model=make_model_f,
            num_warm_start=num_warm_start,            
            init_epochs=init_epochs, 
            update_epochs=update_epochs, 
            batch_size=batch_size, 
            **network_kwargs)

    val_dataset = {'ob': None, 'ac': None, 'ob_next': None}
    if update_with_validation:
        logger.log('Update with validation')
        val_dataset = load_val_data(validation_set_path)
    if eval_val_err:
        logger.log('Log val error')
        eval_val_dataset = load_val_data(eval_val_err_path)       
    if collect_val_data:
        logger.log('Collect validation data')
        val_dataset_collect = [] 

    def _mf_pi(ob, t=None):
        stochastic = True
        ac, vpred, _, _ = pi.step(ob, stochastic=stochastic)
        return ac, vpred   
    def _mf_det_pi(ob, t=None):
        #ac, vpred, _, _ = pi.step(ob, stochastic=False)
        ac, vpred = pi._evaluate([pi.pd.mode(), pi.vf], ob)        
        return ac, vpred
    def _mf_ent_pi(ob, t=None):
        mean, std, vpred = pi._evaluate([pi.pd.mode(), pi.pd.std, pi.vf], ob)
        ac = np.random.normal(mean, std * adj_std_scale, size=mean.shape)
        return ac, vpred
    # When use_ent_adjust is set, _mbmf_inner_pi samples actions via _mf_ent_pi, i.e. from a
    # Gaussian around the policy mean with the std scaled by adj_std_scale, instead of using
    # the policy's own sampling.
    def _mbmf_inner_pi(ob, t=0):
        if use_ent_adjust:
            return _mf_ent_pi(ob)
        else:
            #return _mf_pi(ob)
            if t < mbl_sh: return _mf_pi(ob)        
            else: return _mf_det_pi(ob)

    # ---------------------------------------

    # Evaluate multiple MBL configurations in a single run
    all_eval_descs = []
    def make_mbmf_pi(n, h, e, l):
        def _mbmf_pi(ob):                        
            ac, rew = mbl.step(ob=ob, pi=_mbmf_inner_pi, horizon=h, num_samples=n, num_elites=e, gamma=mbl_gamma, lamb=l, use_mean_elites=use_mean_elites) 
            return ac[None], rew
        return Policy(step=_mbmf_pi, reset=None)

    for n in num_samples:
        for h in horizon:
            for l in mbl_lamb:
                for e in num_elites:                     
                    if 'mbmf' in eval_targs: all_eval_descs.append(('MeanRew', 'MBL_PPO_SIL', make_mbmf_pi(n, h, e, l)))
                    #if 'mbmf' in eval_targs: all_eval_descs.append(('MeanRew-n-{}-h-{}-e-{}-l-{}-sh-{}-me-{}'.format(n, h, e, l, mbl_sh, use_mean_elites), 'MBL_TRPO-n-{}-h-{}-e-{}-l-{}-sh-{}-me-{}'.format(n, h, e, l, mbl_sh, use_mean_elites), make_mbmf_pi(n, h, e, l)))                   
    if 'mf' in eval_targs: all_eval_descs.append(('MeanRew', 'PPO_SIL', Policy(step=_mf_pi, reset=None)))
   
    logger.log('List of evaluation targets')
    for it in all_eval_descs:
        logger.log(it[0])    

    @contextmanager
    def timed(msg):
        if rank == 0:
            print(colorize(msg, color='magenta'))
            tstart = time.time()
            yield
            print(colorize("done in %.3f seconds"%(time.time() - tstart), color='magenta'))
        else:
            yield

    pool = Pool(mp.cpu_count())
    warm_start_done = False
    U.initialize()
    if load_path is not None:
        pi.load(load_path)

    # Instantiate the runner object
    runner = Runner(env=env, model=model, nsteps=nsteps, gamma=gamma, lam=lam)
    epinfobuf = deque(maxlen=40)
    if init_fn is not None: init_fn()

    if traj_collect == 'mf':
        obs = runner.run()[0]
    
    # Start total timer
    tfirststart = time.perf_counter()

    nupdates = total_timesteps//nbatch
    for update in range(1, nupdates+1):
        assert nbatch % nminibatches == 0
        # Update running normalization statistics with the most recent batch
        # (returns only exist after the first runner.run() call inside this loop)
        if update > 1 and hasattr(model.train_model, "ret_rms"):
            model.train_model.ret_rms.update(returns)
        if hasattr(model.train_model, "rms"):
            model.train_model.rms.update(obs)
        # Start timer
        tstart = time.perf_counter()
        frac = 1.0 - (update - 1.0) / nupdates
        # Calculate the learning rate
        lrnow = lr(frac)
        # Calculate the cliprange
        cliprangenow = cliprange(frac)

        if update % log_interval == 0 and is_mpi_root: logger.info('Stepping environment...')

        # Get minibatch
        obs, returns, masks, actions, values, neglogpacs, states, epinfos = runner.run() #pylint: disable=E0632
        
        # Val data collection
        if collect_val_data:
            for ob_, ac_, ob_next_ in zip(obs[:-1, 0, ...], actions[:-1, ...], obs[1:, 0, ...]):            
                val_dataset_collect.append((copy.copy(ob_), copy.copy(ac_), copy.copy(ob_next_)))
        # -----------------------------
        # MBL update
        else:
            ob_mbl, ac_mbl = obs.copy(), actions.copy()
        
            mbl.add_data_batch(ob_mbl[:-1, ...], ac_mbl[:-1, ...], ob_mbl[1:, ...])
            mbl.update_forward_dynamic(require_update=(update-1) % mbl_train_freq == 0, 
                    ob_val=val_dataset['ob'], ac_val=val_dataset['ac'], ob_next_val=val_dataset['ob_next'])            
        # -----------------------------
        
        if update % log_interval == 0 and is_mpi_root: logger.info('Done.')

        epinfobuf.extend(epinfos)

        # Here what we're going to do is for each minibatch calculate the loss and append it.
        mblossvals = []
        if states is None: # nonrecurrent version
            # Index of each element of batch_size
            # Create the indices array
            inds = np.arange(nbatch)
            for _ in range(noptepochs):
                # Randomize the indexes
                np.random.shuffle(inds)
                # 0 to batch_size with batch_train_size step
                for start in range(0, nbatch, nbatch_train):
                    end = start + nbatch_train
                    mbinds = inds[start:end]
                    slices = (arr[mbinds] for arr in (obs, returns, masks, actions, values, neglogpacs))
                    mblossvals.append(model.train(lrnow, cliprangenow, *slices))
            l_loss, sil_adv, sil_samples, sil_nlogp = model.sil_train(lrnow)
            
        else: # recurrent version
            print("caole")
            assert nenvs % nminibatches == 0
            envsperbatch = nenvs // nminibatches
            envinds = np.arange(nenvs)
            flatinds = np.arange(nenvs * nsteps).reshape(nenvs, nsteps)
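            # Illustrative indexing (assumed sizes): with nenvs = 4 and nsteps = 3, flatinds is
            # [[0,1,2],[3,4,5],[6,7,8],[9,10,11]]; choosing mbenvinds = [1, 3] and ravelling gives
            # [3,4,5,9,10,11], so each environment's timesteps stay contiguous and aligned with
            # its recurrent state slice mbstates = states[mbenvinds].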
            for _ in range(noptepochs):
                np.random.shuffle(envinds)
                for start in range(0, nenvs, envsperbatch):
                    end = start + envsperbatch
                    mbenvinds = envinds[start:end]
                    mbflatinds = flatinds[mbenvinds].ravel()
                    slices = (arr[mbflatinds] for arr in (obs, returns, masks, actions, values, neglogpacs))
                    mbstates = states[mbenvinds]
                    mblossvals.append(model.train(lrnow, cliprangenow, *slices, mbstates))

        # Feedforward --> get losses --> update
        lossvals = np.mean(mblossvals, axis=0)
        # End timer
        tnow = time.perf_counter()
        # Calculate the fps (frame per second)
        fps = int(nbatch / (tnow - tstart))

        if update_fn is not None:
            update_fn(update)

        if update % log_interval == 0 or update == 1:
            # Check whether the value function is a good predictor of the returns
            # (ev close to 1) or worse than predicting nothing (ev <= 0)
            ev = explained_variance(values, returns)
            logger.logkv("misc/serial_timesteps", update*nsteps)
            logger.logkv("misc/nupdates", update)
            logger.logkv("misc/total_timesteps", update*nbatch)
            logger.logkv("fps", fps)
            logger.logkv("misc/explained_variance", float(ev))
            logger.logkv('eprewmean', safemean([epinfo['r'] for epinfo in epinfobuf]))
            logger.logkv("AverageReturn", safemean([epinfo['r'] for epinfo in epinfobuf]))
            logger.logkv('eplenmean', safemean([epinfo['l'] for epinfo in epinfobuf]))
            logger.logkv('misc/time_elapsed', tnow - tfirststart)
            for (lossval, lossname) in zip(lossvals, model.loss_names):
                logger.logkv('loss/' + lossname, lossval)
            if sil_update > 0:
                logger.logkv("sil_samples", sil_samples)
                    
            if rank==0:
                # MBL evaluation
                if not collect_val_data:
                    #set_global_seeds(seed)
                    default_sess = tf.get_default_session()
                    def multithread_eval_policy(env_, pi_, num_episodes_, vis_eval_,seed):
                        with default_sess.as_default():
                            if hasattr(env, 'ob_rms') and hasattr(env_, 'ob_rms'):
                                env_.ob_rms = env.ob_rms 
                            res = eval_policy(env_, pi_, num_episodes_, vis_eval_, seed, measure_time, measure_rew) 
                            
                            try:
                                env_.close()
                            except:
                                pass
                        return res

                    if mbl.is_warm_start_done() and update % eval_freq == 0:
                        warm_start_done = mbl.is_warm_start_done()
                        if num_eval_episodes > 0 :
                            targs_names = {}
                            with timed('eval'):
                                num_descs = len(all_eval_descs)
                                list_field_names = [e[0] for e in all_eval_descs]
                                list_legend_names = [e[1] for e in all_eval_descs]
                                list_pis = [e[2] for e in all_eval_descs]                    
                                list_eval_envs = [make_eval_env() for _ in range(num_descs)]
                                list_seed= [seed for _ in range(num_descs)]
                                list_num_eval_episodes = [num_eval_episodes for _ in range(num_descs)]
                                print(list_field_names)
                                print(list_legend_names)
                                
                                list_vis_eval = [vis_eval for _ in range(num_descs)]

                                for i in range(num_descs):
                                    field_name, legend_name = list_field_names[i], list_legend_names[i]
                                    res = multithread_eval_policy(list_eval_envs[i], list_pis[i], list_num_eval_episodes[i], list_vis_eval[i], seed)
                                    #eval_results = pool.starmap(multithread_eval_policy, zip(list_eval_envs, list_pis, list_num_eval_episodes, list_vis_eval, list_seed))
                                    #for field_name, legend_name, res in zip(list_field_names, list_legend_names, eval_results):
                                    perf, elapsed_time, eval_rew = res
                                    logger.logkv(field_name, perf)
                                    if measure_time: logger.logkv('Time-%s' % (field_name), elapsed_time)
                                    if measure_rew: logger.logkv('SimRew-%s' % (field_name), eval_rew)
                                    targs_names[field_name] = legend_name
        
                        if eval_val_err:
                            fwd_dynamics_err = mbl.eval_forward_dynamic(obs=eval_val_dataset['ob'], acs=eval_val_dataset['ac'], obs_next=eval_val_dataset['ob_next'])        
                            logger.logkv('FwdValError', fwd_dynamics_err)

                        #logger.dump_tabular()
                        logger.dumpkvs()
                        #print(logger.get_dir())
                        #print(targs_names)
                        #if num_eval_episodes > 0:
                        #    win = plot(viz, win, logger.get_dir(), targs_names=targs_names, quant=quant, opt='best')
                    #else: logger.dumpkvs()
                # -----------
            yield pi
            
        if collect_val_data:
            with open(validation_set_path, 'wb') as f:
                pickle.dump(val_dataset_collect, f)
            logger.log('Save {} validation data'.format(len(val_dataset_collect)))
        if save_interval and (update % save_interval == 0 or update == 1) and logger.get_dir() and is_mpi_root:
            checkdir = osp.join(logger.get_dir(), 'checkpoints')
            os.makedirs(checkdir, exist_ok=True)
            savepath = osp.join(checkdir, '%.5i'%update)
            print('Saving to', savepath)
            model.save(savepath)
        

    return model
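Both learn() variants above and below drive their hyperparameter schedules through frac = 1 - (update - 1) / nupdates, which decays from 1.0 toward 0.0 over training; lr(frac) and cliprange(frac) are re-evaluated every update. A minimal, self-contained sketch of that convention follows; the constants 3e-4, 0.2 and nupdates = 488 are illustrative assumptions, not values from a particular run.

def linear_schedule(initial_value):
    """Return f(frac) = initial_value * frac; frac anneals from 1.0 (start) toward 0.0 (end)."""
    def f(frac):
        return initial_value * frac
    return f

lr = linear_schedule(3e-4)
cliprange = linear_schedule(0.2)

nupdates = 488  # e.g. total_timesteps = 1_000_000 with nbatch = 2048
for update in (1, nupdates // 2, nupdates):
    frac = 1.0 - (update - 1.0) / nupdates
    print('update %d: lr=%.6f cliprange=%.4f' % (update, lr(frac), cliprange(frac)))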
Example #8
0
def learn(*, network, env, total_timesteps, eval_env=None, seed=None, nsteps=2048, ent_coef=0.0, lr=3e-4, vf_coef=0.5,
          max_grad_norm=0.5, gamma=0.99, lam=0.95, log_interval=10, nminibatches=4, noptepochs=4, cliprange=0.2,
          save_interval=1, load_path=None, model_fn=None, update_fn=None, init_fn=None, mpi_rank_weight=1, comm=None,
          nagent=1, anneal_bound=500, **network_kwargs):
    '''
    Learn policy using PPO algorithm (https://arxiv.org/abs/1707.06347)

    Parameters:
    ----------

    network:                          policy network architecture. Either string (mlp, lstm, lnlstm, cnn_lstm, cnn, cnn_small, conv_only - see baselines.common/models.py for full list)
                                      specifying the standard network architecture, or a function that takes tensorflow tensor as input and returns
                                      tuple (output_tensor, extra_feed) where output tensor is the last network layer output, extra_feed is None for feed-forward
                                      neural nets, and extra_feed is a dictionary describing how to feed state into the network for recurrent neural nets.
                                      See common/models.py/lstm for more details on using recurrent nets in policies

    env: baselines.common.vec_env.VecEnv     environment. Needs to be vectorized for parallel environment simulation.
                                      The environments produced by gym.make can be wrapped using baselines.common.vec_env.DummyVecEnv class.


    nsteps: int                       number of steps of the vectorized environment per update (i.e. batch size is nsteps * nenv where
                                      nenv is number of environment copies simulated in parallel)

    total_timesteps: int              number of timesteps (i.e. number of actions taken in the environment)

    ent_coef: float                   policy entropy coefficient in the optimization objective

    lr: float or function             learning rate, constant or a schedule function [0,1] -> R+ where 1 is beginning of the
                                      training and 0 is the end of the training.

    vf_coef: float                    value function loss coefficient in the optimization objective

    max_grad_norm: float or None      gradient norm clipping coefficient

    gamma: float                      discounting factor

    lam: float                        advantage estimation discounting factor (lambda in the paper)

    log_interval: int                 number of timesteps between logging events

    nminibatches: int                 number of training minibatches per update. For recurrent policies,
                                      should be smaller than or equal to the number of environments run in parallel.

    noptepochs: int                   number of training epochs per update

    cliprange: float or function      clipping range, constant or schedule function [0,1] -> R+ where 1 is beginning of the training
                                      and 0 is the end of the training

    save_interval: int                number of timesteps between saving events

    load_path: str                    path to load the model from

    nagent: int                       number of agents in an environment

    anneal_bound: int                 the number of iterations it takes for dense reward anneal to 0

    **network_kwargs:                 keyword arguments to the policy / network builder. See baselines.common/policies.py/build_policy and arguments to a particular type of network
                                      For instance, 'mlp' network architecture has arguments num_hidden and num_layers.



    '''

    set_global_seeds(seed)

    if isinstance(lr, float): lr = constfn(lr)
    else: assert callable(lr)
    if isinstance(cliprange, float): cliprange = constfn(cliprange)
    else: assert callable(cliprange)
    total_timesteps = int(total_timesteps)

    policy = build_policy(env, network, **network_kwargs)

    # Get the nb of env
    nenvs = env.num_envs

    # Get state_space and action_space
    ob_space = env.observation_space[0]
    ac_space = env.action_space[0]

    # Calculate the batch_size
    nbatch = nenvs * nsteps
    nbatch_train = nbatch // nminibatches

    # Instantiate the model object (that creates act_model and train_model)
    if model_fn is None:
        from model import Model
        model_fn = Model

    model = model_fn(policy=policy, ob_space=ob_space, ac_space=ac_space, nbatch_act=nenvs, nbatch_train=nbatch_train,
                     nsteps=nsteps, ent_coef=ent_coef, vf_coef=vf_coef, max_grad_norm=max_grad_norm,
                     model_scope='model_%d' % 0)
    models = [model]
    for i in range(1, nagent):
        models.append(
            model_fn(policy=policy, ob_space=ob_space, ac_space=ac_space, nbatch_act=nenvs, nbatch_train=nbatch_train,
                     nsteps=nsteps, ent_coef=ent_coef, vf_coef=vf_coef, max_grad_norm=max_grad_norm, trainable=False,
                     model_scope='model_%d' % i))
    writer = tf.summary.FileWriter(logger.get_dir(), tf.get_default_session().graph)

    if load_path is not None:
        for i in range(nagent):
            models[i].load(load_path)

    # Instantiate the runner object
    runner = Runner(env=env, models=models, nsteps=nsteps, nagent=nagent, gamma=gamma, lam=lam, anneal_bound=anneal_bound)
    if eval_env is not None:
        eval_runner = Runner(env=eval_env, models=models, nsteps=nsteps, nagent=nagent, gamma=gamma, lam=lam)
    epinfobuf = deque(maxlen=100)
    if eval_env is not None:
        eval_epinfobuf = deque(maxlen=100)

    if init_fn is not None:
        init_fn()

    # Start total timer
    tfirststart = time.perf_counter()
    checkdir = osp.join(logger.get_dir(), 'checkpoints')

    # number of iterations
    nupdates = total_timesteps//nbatch
    for update in range(1, nupdates+1):
        print('Iteration: %d/%d' % (update, nupdates))
        assert nbatch % nminibatches == 0
        # Start timer
        tstart = time.perf_counter()
        frac = 1.0 - (update - 1.0) / nupdates
        # Calculate the learning rate
        lrnow = lr(frac)
        # Calculate the cliprange
        cliprangenow = cliprange(frac)

        # Set opponents' model
        if update == 1:
            logger.info('Stepping environment...Compete with random opponents')
        else:
            # different environment get different opponent model
            # all parallel environments get same opponent model
            old_versions = [round(np.random.uniform(1, update - 1)) for _ in range(nagent - 1)]
            old_model_paths = [osp.join(checkdir, '%.5i' % old_id) for old_id in old_versions]
            for i in range(1, nagent):
                runner.models[i].load(old_model_paths[i - 1])
            if update % log_interval == 0:
                logger.info('Stepping environment...Compete with', ', '.join([str(old_id) for old_id in old_versions]))

        # Get minibatch
        obs, returns, masks, actions, values, neglogpacs, rewards, states, epinfos = runner.run(update)
        if eval_env is not None:
            eval_obs, eval_returns, eval_masks, eval_actions, eval_values, eval_neglogpacs, eval_rewards, \
            eval_states, eval_epinfos = eval_runner.run()

        if update % log_interval == 0:
            logger.info('Done.')

        epinfobuf.extend(epinfos)
        if eval_env is not None:
            eval_epinfobuf.extend(eval_epinfos)

        # Here what we're going to do is for each minibatch calculate the loss and append it.
        mblossvals = []
        if states is None: # nonrecurrent version
            # Index of each element of batch_size
            # Create the indices array
            inds = np.arange(nbatch)
            for epoch in range(noptepochs):
                # Randomize the indexes
                np.random.shuffle(inds)
                # 0 to batch_size with batch_train_size step
                for ii, start in enumerate(range(0, nbatch, nbatch_train)):
                    end = start + nbatch_train
                    mbinds = inds[start:end]
                    slices = (arr[mbinds] for arr in (obs, returns, masks, actions, values, neglogpacs, rewards))
                    temp_out = model.train(lrnow, cliprangenow, *slices)
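                    # The summary's global step counts minibatches over the whole run:
                    # (update - 1) * noptepochs * nminibatches from earlier updates,
                    # plus epoch * nminibatches from earlier epochs of this update, plus ii.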
                    writer.add_summary(temp_out[-1], (update - 1) * noptepochs * nminibatches + epoch * nminibatches + ii)
                    mblossvals.append(temp_out[:-1])
        else: # recurrent version
            assert nenvs % nminibatches == 0
            envsperbatch = nenvs // nminibatches
            envinds = np.arange(nenvs)
            flatinds = np.arange(nenvs * nsteps).reshape(nenvs, nsteps)
            for epoch in range(noptepochs):
                np.random.shuffle(envinds)
                for ii, start in enumerate(range(0, nenvs, envsperbatch)):
                    end = start + envsperbatch
                    mbenvinds = envinds[start:end]
                    mbflatinds = flatinds[mbenvinds].ravel()
                    slices = (arr[mbflatinds] for arr in (obs, returns, masks, actions, values, neglogpacs, rewards))
                    mbstates = states[mbenvinds]
                    temp_out = model.train(lrnow, cliprangenow, *slices, mbstates)
                    writer.add_summary(temp_out[-1], (update - 1) * noptepochs * nminibatches + epoch * nminibatches + ii)
                    mblossvals.append(temp_out[:-1])

        # Feedforward --> get losses --> update
        lossvals = np.mean(mblossvals, axis=0)
        # End timer
        tnow = time.perf_counter()

        if update_fn is not None:
            update_fn(update)

        if update % log_interval == 0 or update == 1:
            # Check whether the value function is a good predictor of the returns
            # (ev close to 1) or worse than predicting nothing (ev <= 0)
            ev = explained_variance(values, returns)
            logger.logkv("misc/serial_timesteps", update*nsteps)
            logger.logkv("misc/nupdates", update)
            logger.logkv("misc/total_timesteps", update*nbatch)
            logger.logkv("misc/explained_variance", float(ev))
            logger.logkv('eprewmean', safemean([epinfo['r'] for epinfo in epinfobuf]))
            logger.logkv('epdenserewmean', safemean([epinfo['dr'] for epinfo in epinfobuf]))
            logger.logkv('eplenmean', safemean([epinfo['l'] for epinfo in epinfobuf]))
            # if eval_env is not None:
            #     logger.logkv('eval_eprewmean', safemean([epinfo['r'] for epinfo in eval_epinfobuf]) )
            #     logger.logkv('eval_eplenmean', safemean([epinfo['l'] for epinfo in eval_epinfobuf]) )
            logger.logkv('misc/time_elapsed', tnow - tfirststart)
            for (lossval, lossname) in zip(lossvals, model.loss_names):
                logger.logkv('loss/' + lossname, lossval)

            logger.dumpkvs()
        if save_interval and (update % save_interval == 0 or update == 1) and logger.get_dir():
            checkdir = osp.join(logger.get_dir(), 'checkpoints')
            os.makedirs(checkdir, exist_ok=True)
            savepath = osp.join(checkdir, '%.5i'%update)
            print('Saving to', savepath)
            model.save(savepath)

    writer.close()
    return model
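Example #8's self-play loop refreshes its opponents every iteration by loading a uniformly sampled earlier checkpoint into each non-trainable model, so the learner always plays against its own past versions. A minimal sketch of just that sampling step; the function name and the printed call are illustrative, not part of the example.

import os.path as osp
import numpy as np

def sample_opponent_paths(checkdir, update, nagent):
    """Pick one earlier checkpoint path per opponent; empty on the first update (random opponents)."""
    if update == 1:
        return []
    old_versions = [int(round(np.random.uniform(1, update - 1))) for _ in range(nagent - 1)]
    return [osp.join(checkdir, '%.5i' % old_id) for old_id in old_versions]

print(sample_opponent_paths('/tmp/run/checkpoints', update=10, nagent=3))
# e.g. ['/tmp/run/checkpoints/00004', '/tmp/run/checkpoints/00007']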