Пример #1
0
def learn(policy,
          env,
          nsteps,
          total_timesteps,
          ent_coef,
          lr,
          vf_coef=0.5,
          max_grad_norm=0.5,
          gamma=0.99,
          lam=0.95,
          log_interval=10,
          nminibatches=4,
          noptepochs=4,
          cliprange=0.2,
          save_interval=0,
          load_path=None,
          train_callback=None,
          eval_callback=None,
          cloud_sync_callback=None,
          cloud_sync_interval=1000,
          workdir='',
          use_curiosity=False,
          curiosity_strength=0.01,
          forward_inverse_ratio=0.2,
          curiosity_loss_strength=10,
          random_state_predictor=False):
    """Trains a PPO2 model on the vectorized environment `env`.

    Args:
      policy: policy constructor forwarded to `Model`.
      env: vectorized environment; `num_envs`, `observation_space` and
        `action_space` are read from it. It is closed before returning, also
        when training raises.
      nsteps: environment steps collected per env for each update.
      total_timesteps: total environment steps; the number of updates is
        `total_timesteps // (env.num_envs * nsteps)`.
      ent_coef, vf_coef, max_grad_norm: forwarded to `Model`.
      lr: learning rate. Either a number (used as a constant) or a callable
        receiving the fraction of training remaining (1.0 on the first
        update, decreasing afterwards).
      gamma, lam: forwarded to `Runner`.
      log_interval: log on the first update and every `log_interval` updates.
      nminibatches: minibatches per optimization epoch; must divide
        `env.num_envs * nsteps` (and `env.num_envs` for recurrent policies).
      noptepochs: optimization epochs per update.
      cliprange: clip range, a number or a callable like `lr`.
      save_interval: if non-zero (and `workdir` is set), pickle `make_model`
        once and checkpoint on update 1 and every `save_interval` updates.
      load_path: optional path passed to `model.load` before training.
      train_callback: optional `f(eplenmean, eprewmean, total_timesteps)`
        invoked at log time.
      eval_callback: forwarded to `Runner`.
      cloud_sync_callback: optional no-argument callable.
      cloud_sync_interval: `cloud_sync_callback` runs every this many updates.
      workdir: directory for `make_model.pkl` and `checkpoints/`.
      use_curiosity, curiosity_strength, forward_inverse_ratio,
      curiosity_loss_strength, random_state_predictor: forwarded to `Model`.

    Returns:
      The trained model.

    Raises:
      AssertionError: if `lr`/`cliprange` are neither numbers nor callables,
        or if the batch cannot be split into `nminibatches` minibatches.
    """
    # Plain numbers become constant schedules; otherwise a schedule callable
    # is required.
    if isinstance(lr, (int, float)):
        lr = constfn(lr)
    else:
        assert callable(lr)
    if isinstance(cliprange, (int, float)):
        cliprange = constfn(cliprange)
    else:
        assert callable(cliprange)
    total_timesteps = int(total_timesteps)

    nenvs = env.num_envs
    ob_space = env.observation_space
    ac_space = env.action_space
    nbatch = nenvs * nsteps
    # Loop-invariant: validated once, before any model is built.
    assert nbatch % nminibatches == 0
    nbatch_train = nbatch // nminibatches

    # pylint: disable=g-long-lambda
    make_model = lambda: Model(policy=policy,
                               ob_space=ob_space,
                               ac_space=ac_space,
                               nbatch_act=nenvs,
                               nbatch_train=nbatch_train,
                               nsteps=nsteps,
                               ent_coef=ent_coef,
                               vf_coef=vf_coef,
                               max_grad_norm=max_grad_norm,
                               use_curiosity=use_curiosity,
                               curiosity_strength=curiosity_strength,
                               forward_inverse_ratio=forward_inverse_ratio,
                               curiosity_loss_strength=curiosity_loss_strength,
                               random_state_predictor=random_state_predictor)
    # pylint: enable=g-long-lambda
    if save_interval and workdir:
        with tf.gfile.Open(osp.join(workdir, 'make_model.pkl'), 'wb') as fh:
            fh.write(dill.dumps(make_model))
    model = make_model()
    if load_path is not None:
        model.load(load_path)
    runner = Runner(env=env,
                    model=model,
                    nsteps=nsteps,
                    gamma=gamma,
                    lam=lam,
                    eval_callback=eval_callback)

    epinfobuf = deque(maxlen=100)
    tfirststart = time.time()

    nupdates = total_timesteps // nbatch
    try:
        for update in range(1, nupdates + 1):
            tstart = time.time()
            # Fraction of training remaining: 1.0 on the first update.
            frac = 1.0 - (update - 1.0) / nupdates
            lrnow = lr(frac)
            cliprangenow = cliprange(frac)
            ((obs, next_obs, returns, masks, actions, values, neglogpacs),
             states, epinfos) = runner.run()
            epinfobuf.extend(epinfos)
            mblossvals = []
            if states is None:  # nonrecurrent version
                inds = np.arange(nbatch)
                for _ in range(noptepochs):
                    np.random.shuffle(inds)
                    for start in range(0, nbatch, nbatch_train):
                        mbinds = inds[start:start + nbatch_train]
                        (mb_obs, mb_returns, mb_masks, mb_actions, mb_values,
                         mb_neglogpacs, mb_next_obs) = (
                             arr[mbinds]
                             for arr in (obs, returns, masks, actions, values,
                                         neglogpacs, next_obs))
                        mblossvals.append(
                            model.train(lrnow, cliprangenow, mb_obs,
                                        mb_next_obs, mb_returns, mb_masks,
                                        mb_actions, mb_values, mb_neglogpacs))
            else:  # recurrent version
                # Minibatches are built from whole environments: each chosen
                # env contributes all of its `nsteps` flat indices plus its
                # recurrent state.
                assert nenvs % nminibatches == 0
                envsperbatch = nenvs // nminibatches
                envinds = np.arange(nenvs)
                flatinds = np.arange(nenvs * nsteps).reshape(nenvs, nsteps)
                for _ in range(noptepochs):
                    np.random.shuffle(envinds)
                    for start in range(0, nenvs, envsperbatch):
                        mbenvinds = envinds[start:start + envsperbatch]
                        mbflatinds = flatinds[mbenvinds].ravel()
                        (mb_obs, mb_returns, mb_masks, mb_actions, mb_values,
                         mb_neglogpacs, mb_next_obs) = (
                             arr[mbflatinds]
                             for arr in (obs, returns, masks, actions, values,
                                         neglogpacs, next_obs))
                        mbstates = states[mbenvinds]
                        mblossvals.append(
                            model.train(lrnow, cliprangenow, mb_obs,
                                        mb_next_obs, mb_returns, mb_masks,
                                        mb_actions, mb_values, mb_neglogpacs,
                                        mbstates))

            lossvals = np.mean(mblossvals, axis=0)
            tnow = time.time()
            fps = int(nbatch / (tnow - tstart))
            if update % log_interval == 0 or update == 1:
                ev = explained_variance(values, returns)
                eplenmean = safemean([epinfo['l'] for epinfo in epinfobuf])
                eprewmean = safemean([epinfo['r'] for epinfo in epinfobuf])
                logger.logkv('serial_timesteps', update * nsteps)
                logger.logkv('nupdates', update)
                logger.logkv('total_timesteps', update * nbatch)
                logger.logkv('fps', fps)
                logger.logkv('explained_variance', float(ev))
                logger.logkv('eprewmean', eprewmean)
                logger.logkv('eplenmean', eplenmean)
                if train_callback:
                    train_callback(eplenmean, eprewmean, update * nbatch)
                logger.logkv('time_elapsed', tnow - tfirststart)
                for (lossval, lossname) in zip(lossvals, model.loss_names):
                    logger.logkv(lossname, lossval)
                logger.dumpkvs()
            if (save_interval and (update % save_interval == 0 or update == 1)
                    and workdir):
                checkdir = osp.join(workdir, 'checkpoints')
                if not tf.gfile.Exists(checkdir):
                    tf.gfile.MakeDirs(checkdir)
                savepath = osp.join(checkdir, '%.5i' % update)
                print('Saving to', savepath)
                model.save(savepath)
            if (cloud_sync_interval and update % cloud_sync_interval == 0
                    and cloud_sync_callback):
                cloud_sync_callback()
    finally:
        # Release the environment even if training fails part-way.
        env.close()
    return model
Пример #2
0
 def eval_callback_on_test(eprewmean, global_step_val):
   """Reports the test-set mean episode reward.

   When `test_measurements` is truthy, a measurement with the reward as the
   objective value is created at `global_step_val`. The reward is always
   logged under 'eprewmean_test'.
   """
   sink = test_measurements
   if sink:
     sink.create_measurement(objective_value=eprewmean, step=global_step_val)
   logger.logkv('eprewmean_test', eprewmean)
Пример #3
0
 def eval_callback_on_valid(eprewmean, global_step_val):
   """Reports the validation-set mean episode reward.

   When `valid_measurements` is truthy, a measurement with the reward as the
   objective value is created at `global_step_val`. The reward is always
   logged under 'eprewmean_valid'.
   """
   sink = valid_measurements
   if sink:
     sink.create_measurement(objective_value=eprewmean, step=global_step_val)
   logger.logkv('eprewmean_valid', eprewmean)
Пример #4
0
 def measurement_callback(unused_eplenmean, eprewmean, global_step_val):
   """Reports the training mean episode reward.

   The first argument (mean episode length) is accepted but ignored. When
   `train_measurements` is truthy, a measurement with the reward as the
   objective value is created at `global_step_val`. The reward is always
   logged under 'eprewmean_train'.
   """
   sink = train_measurements
   if sink:
     sink.create_measurement(objective_value=eprewmean, step=global_step_val)
   logger.logkv('eprewmean_train', eprewmean)
Пример #5
0
def _logkv_summary(name, values):
    """Logs mean/std/max/min of `values` as 'mean_<name>', 'std_<name>', ..."""
    logger.logkv('mean_' + name, np.mean(values))
    logger.logkv('std_' + name, np.std(values))
    logger.logkv('max_' + name, np.max(values))
    logger.logkv('min_' + name, np.min(values))


def learn(policy,
          env,
          nsteps,
          total_timesteps,
          ent_coef,
          lr,
          vf_coef=0.5,
          max_grad_norm=0.5,
          gamma=0.99,
          lam=0.95,
          log_interval=10,
          nminibatches=4,
          noptepochs=4,
          cliprange=0.2,
          save_interval=0,
          load_path=None,
          train_callback=None,
          eval_callback=None,
          cloud_sync_callback=None,
          cloud_sync_interval=1000,
          workdir='',
          use_curiosity=False,
          curiosity_strength=0.01,
          forward_inverse_ratio=0.2,
          curiosity_loss_strength=10,
          random_state_predictor=False,
          use_rlb=False,
          checkpoint_path_for_debugging=None):
    """Trains a PPO2 model (optionally with the RLB module) on `env`.

    Args:
      policy: policy constructor forwarded to `Model`.
      env: vectorized environment; `num_envs`, `observation_space` and
        `action_space` are read from it. It is closed before returning, also
        when training raises.
      nsteps: environment steps collected per env for each update.
      total_timesteps: total environment steps; the number of updates is
        `total_timesteps // (env.num_envs * nsteps)`.
      ent_coef, vf_coef, max_grad_norm: forwarded to `Model`.
      lr: learning rate. Either a number (used as a constant) or a callable
        receiving the fraction of training remaining (1.0 on the first
        update, decreasing afterwards).
      gamma, lam: forwarded to `Runner`.
      log_interval: log on the first update and every `log_interval` updates.
      nminibatches: minibatches per optimization epoch; must divide
        `env.num_envs * nsteps` (and `env.num_envs` for recurrent policies).
      noptepochs: optimization epochs per update.
      cliprange: clip range, a number or a callable like `lr`.
      save_interval: if non-zero (and `workdir` is set), pickle `make_model`
        once and write `checkpoints/` (via `model.save`) and
        `full_checkpoints/` (a `tf.train.Saver` checkpoint) on update 1 and
        every `save_interval` updates.
      load_path: optional path passed to `model.load` before training.
      train_callback: optional `f(eplenmean, eprewmean, total_timesteps)`
        invoked at log time.
      eval_callback: forwarded to `Runner`.
      cloud_sync_callback: optional no-argument callable.
      cloud_sync_interval: `cloud_sync_callback` runs every this many updates.
      workdir: directory for the pickle, checkpoints and timeline dumps.
      use_curiosity, curiosity_strength, forward_inverse_ratio,
      curiosity_loss_strength, random_state_predictor, use_rlb: forwarded to
        `Model`.
      checkpoint_path_for_debugging: optional checkpoint from which global
        variables under the 'rlb_model' scope are restored.

    Returns:
      The trained model.

    Raises:
      AssertionError: on invalid `lr`/`cliprange`, a batch that cannot be
        split into `nminibatches`, or unexpected statistics fetches.
    """
    # Plain numbers become constant schedules; otherwise a schedule callable
    # is required.
    if isinstance(lr, (int, float)):
        lr = constfn(lr)
    else:
        assert callable(lr)
    if isinstance(cliprange, (int, float)):
        cliprange = constfn(cliprange)
    else:
        assert callable(cliprange)
    total_timesteps = int(total_timesteps)

    nenvs = env.num_envs
    ob_space = env.observation_space
    ac_space = env.action_space
    nbatch = nenvs * nsteps
    # Loop-invariant: validated once, before any model is built.
    assert nbatch % nminibatches == 0
    nbatch_train = nbatch // nminibatches

    # pylint: disable=g-long-lambda
    make_model = lambda: Model(policy=policy,
                               ob_space=ob_space,
                               ac_space=ac_space,
                               nbatch_act=nenvs,
                               nbatch_train=nbatch_train,
                               nsteps=nsteps,
                               ent_coef=ent_coef,
                               vf_coef=vf_coef,
                               max_grad_norm=max_grad_norm,
                               use_curiosity=use_curiosity,
                               curiosity_strength=curiosity_strength,
                               forward_inverse_ratio=forward_inverse_ratio,
                               curiosity_loss_strength=curiosity_loss_strength,
                               random_state_predictor=random_state_predictor,
                               use_rlb=use_rlb)
    # pylint: enable=g-long-lambda
    if save_interval and workdir:
        with tf.gfile.Open(osp.join(workdir, 'make_model.pkl'), 'wb') as fh:
            fh.write(dill.dumps(make_model))

    with tf.device('/gpu:0'):
        model = make_model()

    if save_interval and workdir:
        # A tf.train.Saver built with the default var_list snapshots the
        # variables existing at construction time, so it must be created
        # after make_model() for the full checkpoints to include the model.
        saver = tf.train.Saver(max_to_keep=10000000)

        def save_state(fname):
            """Writes a tf.train.Saver checkpoint of the default session."""
            dirname = osp.dirname(fname)
            # tf.gfile (like the rest of this function) also handles
            # non-local workdirs, unlike os.path/os.makedirs.
            if not tf.gfile.Exists(dirname):
                tf.gfile.MakeDirs(dirname)
            saver.save(tf.get_default_session(), fname)

    if load_path is not None:
        model.load(load_path)
    runner = Runner(env=env,
                    model=model,
                    nsteps=nsteps,
                    gamma=gamma,
                    lam=lam,
                    eval_callback=eval_callback)

    if checkpoint_path_for_debugging is not None:
        tf_util.load_state(checkpoint_path_for_debugging,
                           var_list=tf.get_collection(
                               tf.GraphKeys.GLOBAL_VARIABLES,
                               scope='rlb_model'))

    epinfobuf = deque(maxlen=100)
    tfirststart = time.time()

    nupdates = total_timesteps // nbatch
    try:
        for update in range(1, nupdates + 1):
            tstart = time.time()
            # Fraction of training remaining: 1.0 on the first update.
            frac = 1.0 - (update - 1.0) / nupdates
            lrnow = lr(frac)
            cliprangenow = cliprange(frac)
            ((obs, next_obs, returns, masks, actions, values, neglogpacs),
             states, epinfos,
             (rewards, rewards_ext, rewards_int, rewards_int_raw,
              selected_infos, dones)) = runner.run()
            epinfobuf.extend(epinfos)
            mblossvals = []
            mbhistos = []
            mbscs = []

            # Timeline profiling covers every update when enabled; add an
            # `update % N == 0` condition here to sample less often.
            if model.all_rlb_args.debug_args['debug_tf_timeline']:
                debug_timeliner = logger.TimeLiner()
            else:
                debug_timeliner = None

            # In both branches histogram stats are gathered on every
            # minibatch of the last epoch, scalar stats only on its final
            # minibatch.
            if states is None:  # nonrecurrent version
                inds = np.arange(nbatch)
                for oe in range(noptepochs):
                    gather_histo = (oe == noptepochs - 1)
                    np.random.shuffle(inds)
                    for start in range(0, nbatch, nbatch_train):
                        end = start + nbatch_train
                        gather_sc = gather_histo and end >= nbatch
                        mbinds = inds[start:end]
                        (mb_obs, mb_returns, mb_masks, mb_actions, mb_values,
                         mb_neglogpacs, mb_next_obs) = (
                             arr[mbinds]
                             for arr in (obs, returns, masks, actions, values,
                                         neglogpacs, next_obs))
                        with logger.ProfileKV('train'):
                            fetches = model.train(
                                lrnow, cliprangenow, mb_obs, mb_next_obs,
                                mb_returns, mb_masks, mb_actions, mb_values,
                                mb_neglogpacs,
                                gather_histo=gather_histo,
                                gather_sc=gather_sc,
                                debug_timeliner=debug_timeliner)
                        mblossvals.append(fetches['losses'])
                        if gather_histo:
                            mbhistos.append(fetches['stats_histo'])
                        if gather_sc:
                            mbscs.append(fetches['stats_sc'])
            else:  # recurrent version
                # Minibatches are built from whole environments: each chosen
                # env contributes all of its `nsteps` flat indices plus its
                # recurrent state.
                assert nenvs % nminibatches == 0
                envsperbatch = nenvs // nminibatches
                envinds = np.arange(nenvs)
                flatinds = np.arange(nenvs * nsteps).reshape(nenvs, nsteps)
                for oe in range(noptepochs):
                    gather_histo = (oe == noptepochs - 1)
                    np.random.shuffle(envinds)
                    for start in range(0, nenvs, envsperbatch):
                        end = start + envsperbatch
                        # `start`/`end` count environments here, so the last
                        # minibatch must be detected against `nenvs`;
                        # comparing with sample counts never matched and made
                        # the single-scalar-fetch assertion below fail.
                        gather_sc = gather_histo and end >= nenvs
                        mbenvinds = envinds[start:end]
                        mbflatinds = flatinds[mbenvinds].ravel()
                        (mb_obs, mb_returns, mb_masks, mb_actions, mb_values,
                         mb_neglogpacs, mb_next_obs) = (
                             arr[mbflatinds]
                             for arr in (obs, returns, masks, actions, values,
                                         neglogpacs, next_obs))
                        mbstates = states[mbenvinds]
                        fetches = model.train(
                            lrnow, cliprangenow, mb_obs, mb_next_obs,
                            mb_returns, mb_masks, mb_actions, mb_values,
                            mb_neglogpacs, mbstates,
                            gather_histo=gather_histo,
                            gather_sc=gather_sc,
                            debug_timeliner=debug_timeliner)
                        mblossvals.append(fetches['losses'])
                        if gather_histo:
                            mbhistos.append(fetches['stats_histo'])
                        if gather_sc:
                            mbscs.append(fetches['stats_sc'])

            if debug_timeliner is not None:
                with logger.ProfileKV("save_timeline_json"):
                    debug_timeliner.save(
                        osp.join(workdir, 'timeline_{}.json'.format(update)))

            lossvals = np.mean(mblossvals, axis=0)
            # Exactly one minibatch per update requests scalar stats.
            assert len(mbscs) == 1
            scalars = mbscs[0]
            histograms = {
                n: np.concatenate([f[n] for f in mbhistos], axis=0)
                for n in model.stats_histo_names
            }
            logger.info('Histograms: {}'.format([(n, histograms[n].shape)
                                                 for n in histograms.keys()]))
            tnow = time.time()
            fps = int(nbatch / (tnow - tstart))
            if update % log_interval == 0 or update == 1:
                fps_total = int((update * nbatch) / (tnow - tfirststart))

                tf_num_ops = len(tf.get_default_graph().get_operations())
                logger.info(
                    '#################### tf_num_ops: {}'.format(tf_num_ops))
                logger.logkv('tf_num_ops', tf_num_ops)
                ev = explained_variance(values, returns)
                eplenmean = safemean([epinfo['l'] for epinfo in epinfobuf])
                eprewmean = safemean([epinfo['r'] for epinfo in epinfobuf])
                logger.logkv('serial_timesteps', update * nsteps)
                logger.logkv('nupdates', update)
                logger.logkv('total_timesteps', update * nbatch)
                logger.logkv('fps', fps)
                logger.logkv('fps_total', fps_total)
                logger.logkv(
                    'remaining_time',
                    float(tnow - tfirststart) / float(update) *
                    float(nupdates - update))
                logger.logkv('explained_variance', float(ev))
                logger.logkv('eprewmean', eprewmean)
                logger.logkv('eplenmean', eplenmean)
                if train_callback:
                    train_callback(eplenmean, eprewmean, update * nbatch)
                logger.logkv('time_elapsed', tnow - tfirststart)
                for (lossval, lossname) in zip(lossvals, model.loss_names):
                    logger.logkv(lossname, lossval)

                for n, v in scalars.items():
                    logger.logkv(n, v)

                for n, v in histograms.items():
                    logger.logkv(n, v)
                    _logkv_summary(n, v)

                # Raw arrays are logged only for the intrinsic rewards;
                # summaries are logged for every reward stream.
                reward_streams = (('rewards', rewards),
                                  ('rewards_ext', rewards_ext),
                                  ('rewards_int', rewards_int),
                                  ('rewards_int_raw', rewards_int_raw))
                for n, v in reward_streams:
                    if n in ('rewards_int', 'rewards_int_raw'):
                        logger.logkv(n, v)
                    _logkv_summary(n, v)

                if model.rlb_model:
                    if model.all_rlb_args.outer_args['rlb_normalize_ir']:
                        logger.logkv('rlb_ir_running_mean',
                                     runner.irff_rms.mean)
                        logger.logkv('rlb_ir_running_std',
                                     np.sqrt(runner.irff_rms.var))

                logger.dumpkvs()
            if (save_interval and (update % save_interval == 0 or update == 1)
                    and workdir):
                for subdir, save_fn in (('checkpoints', model.save),
                                        ('full_checkpoints', save_state)):
                    checkdir = osp.join(workdir, subdir)
                    if not tf.gfile.Exists(checkdir):
                        tf.gfile.MakeDirs(checkdir)
                    savepath = osp.join(checkdir, '%.5i' % update)
                    print('Saving to', savepath)
                    save_fn(savepath)
            if (cloud_sync_interval and update % cloud_sync_interval == 0
                    and cloud_sync_callback):
                cloud_sync_callback()
    finally:
        # Release the environment even if training fails part-way.
        env.close()
    return model
Пример #6
0
  def train(self, batch_gen, steps_per_epoch, num_epochs):
    """Runs `num_epochs` x `steps_per_epoch` training steps, then logs stats.

    Each step draws an `(obs, obs_next, acs)` triple from `batch_gen` and
    runs `self._train` on it. Histogram stats are requested on every step of
    the final epoch. Scalar stats (`stats_sc` and `additional_sc`) are
    requested only on the very last step. Losses are averaged over all steps
    and logged together with the scalars, the histograms and their
    mean/std/max/min summaries.

    Args:
      batch_gen: iterator yielding `(obs, obs_next, acs)` batches.
      steps_per_epoch: batches consumed per epoch.
      num_epochs: number of epochs.
    """
    last_epoch = num_epochs - 1
    last_step = steps_per_epoch - 1
    step_losses = []
    histo_fetches = []
    scalar_fetches = []
    extra_scalar_fetches = []
    for epoch in range(num_epochs):
      want_histo = epoch == last_epoch
      for step in range(steps_per_epoch):
        want_scalars = want_histo and step == last_step
        obs, obs_next, acs = next(batch_gen)
        with logger.ProfileKV('train_ot_inner'):
          fetches = self._train(obs, obs_next, acs,
                                gather_histo=want_histo,
                                gather_sc=want_scalars)
        step_losses.append(fetches['losses'])
        if want_histo:
          histo_fetches.append(fetches['stats_histo'])
        if want_scalars:
          scalar_fetches.append(fetches['stats_sc'])
          extra_scalar_fetches.append(fetches['additional_sc'])

    mean_losses = np.mean(step_losses, axis=0)
    # Only the very last step gathers scalars, so there is one entry of each.
    assert len(scalar_fetches) == 1
    assert len(extra_scalar_fetches) == 1
    scalars = scalar_fetches[0]
    additional_scalars = extra_scalar_fetches[0]
    histograms = {}
    for name in self._stats_histo_names:
      histograms[name] = np.concatenate(
          [fetched[name] for fetched in histo_fetches], axis=0)
    histo_shapes = [(name, hist.shape) for name, hist in histograms.items()]
    logger.info('RLBModelWrapper.train histograms: {}'.format(histo_shapes))

    for lossval, lossname in zip(mean_losses, self._loss_names):
      logger.logkv(lossname, lossval)

    for stats in (scalars, additional_scalars):
      for name, value in stats.items():
        logger.logkv(name, value)

    reducers = (('mean_', np.mean), ('std_', np.std),
                ('max_', np.max), ('min_', np.min))
    for name, hist in histograms.items():
      logger.logkv(name, hist)
      for prefix, reduce_fn in reducers:
        logger.logkv(prefix + name, reduce_fn(hist))