Example no. 1
def learn(*,
          policy,
          env,
          nsteps,
          total_timesteps,
          ent_coef,
          lr,
          vf_coef=0.5,
          max_grad_norm=0.5,
          gamma=0.99,
          lam=0.95,
          log_interval=10,
          nminibatches=4,
          noptepochs=4,
          cliprange=0.2,
          save_interval=0,
          load_path=None):
    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
    mpi_size = comm.Get_size()

    sess = tf.get_default_session()
    tb_writer = TB_Writer(sess)

    if isinstance(lr, float): lr = constfn(lr)
    else: assert callable(lr)
    if isinstance(cliprange, float): cliprange = constfn(cliprange)
    else: assert callable(cliprange)
    total_timesteps = int(total_timesteps)

    nenvs = env.num_envs
    ob_space = env.observation_space
    ac_space = env.action_space
    nbatch = nenvs * nsteps

    nbatch_train = nbatch // nminibatches

    model = Model(policy=policy,
                  ob_space=ob_space,
                  ac_space=ac_space,
                  nbatch_act=nenvs,
                  nbatch_train=nbatch_train,
                  nsteps=nsteps,
                  ent_coef=ent_coef,
                  vf_coef=vf_coef,
                  max_grad_norm=max_grad_norm)

    utils.load_all_params(sess)

    runner = Runner(env=env, model=model, nsteps=nsteps, gamma=gamma, lam=lam)

    epinfobuf10 = deque(maxlen=10)
    epinfobuf100 = deque(maxlen=100)
    tfirststart = time.time()
    active_ep_buf = epinfobuf100

    nupdates = total_timesteps // nbatch
    mean_rewards = []
    datapoints = []

    run_t_total = 0
    train_t_total = 0

    can_save = True
    checkpoints = [32, 64]
    saved_key_checkpoints = [False] * len(checkpoints)

    if Config.SYNC_FROM_ROOT and rank != 0:
        can_save = False

    def save_model(base_name=None):
        base_dict = {'datapoints': datapoints}
        utils.save_params_in_scopes(sess, ['model'],
                                    Config.get_save_file(base_name=base_name),
                                    base_dict)

    for update in range(1, nupdates + 1):
        assert nbatch % nminibatches == 0
        nbatch_train = nbatch // nminibatches
        tstart = time.time()
        frac = 1.0 - (update - 1.0) / nupdates
        lrnow = lr(frac)
        cliprangenow = cliprange(frac)

        mpi_print('collecting rollouts...')
        run_tstart = time.time()

        obs, returns, masks, actions, values, neglogpacs, states, epinfos = runner.run()
        epinfobuf10.extend(epinfos)
        epinfobuf100.extend(epinfos)

        run_elapsed = time.time() - run_tstart
        run_t_total += run_elapsed
        mpi_print('rollouts complete')

        mblossvals = []

        mpi_print('updating parameters...')
        train_tstart = time.time()

        if states is None:  # nonrecurrent version
            inds = np.arange(nbatch)
            for _ in range(noptepochs):
                np.random.shuffle(inds)
                for start in range(0, nbatch, nbatch_train):
                    end = start + nbatch_train
                    mbinds = inds[start:end]
                    slices = (arr[mbinds]
                              for arr in (obs, returns, masks, actions, values,
                                          neglogpacs))
                    mblossvals.append(model.train(lrnow, cliprangenow,
                                                  *slices))
        else:  # recurrent version
            assert nenvs % nminibatches == 0
            envinds = np.arange(nenvs)
            flatinds = np.arange(nenvs * nsteps).reshape(nenvs, nsteps)
            envsperbatch = nbatch_train // nsteps
            for _ in range(noptepochs):
                np.random.shuffle(envinds)
                for start in range(0, nenvs, envsperbatch):
                    end = start + envsperbatch
                    mbenvinds = envinds[start:end]
                    mbflatinds = flatinds[mbenvinds].ravel()
                    slices = (arr[mbflatinds]
                              for arr in (obs, returns, masks, actions, values,
                                          neglogpacs))
                    mbstates = states[mbenvinds]
                    mblossvals.append(
                        model.train(lrnow, cliprangenow, *slices, mbstates))

        # update the dropout mask
        sess.run([model.train_model.dropout_assign_ops])

        train_elapsed = time.time() - train_tstart
        train_t_total += train_elapsed
        mpi_print('update complete')

        lossvals = np.mean(mblossvals, axis=0)
        tnow = time.time()
        fps = int(nbatch / (tnow - tstart))

        step = update * nbatch
        if update % log_interval == 0 or update == 1:
            rew_mean_10 = utils.process_ep_buf(active_ep_buf,
                                               tb_writer=tb_writer,
                                               suffix='',
                                               step=step)
            ep_len_mean = np.nanmean([epinfo['l'] for epinfo in active_ep_buf])

            mpi_print('\n----', update)

            mean_rewards.append(rew_mean_10)
            datapoints.append([step, rew_mean_10])

            tb_writer.log_scalar(ep_len_mean, 'ep_len_mean')
            tb_writer.log_scalar(fps, 'fps')

            mpi_print('time_elapsed', tnow - tfirststart, run_t_total,
                      train_t_total)
            mpi_print('timesteps', update * nsteps, total_timesteps)

            mpi_print('eplenmean', ep_len_mean)
            mpi_print('eprew', rew_mean_10)
            mpi_print('fps', fps)
            mpi_print('total_timesteps', update * nbatch)
            mpi_print([epinfo['r'] for epinfo in epinfobuf10])

            if len(mblossvals):
                for (lossval, lossname) in zip(lossvals, model.loss_names):
                    mpi_print(lossname, lossval)
                    tb_writer.log_scalar(lossval, lossname)
            mpi_print('----\n')

        if can_save:
            if save_interval and (update % save_interval == 0):
                save_model()

            for j, checkpoint in enumerate(checkpoints):
                if (not saved_key_checkpoints[j]) and (step >=
                                                       (checkpoint * 1e6)):
                    saved_key_checkpoints[j] = True
                    save_model(str(checkpoint) + 'M')

    save_model()

    env.close()
    return mean_rewards
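In this example `lr` and `cliprange` can be either constants or schedules: a float is wrapped by `constfn`, and a callable receives `frac = 1.0 - (update - 1.0) / nupdates`, which shrinks from 1.0 towards 0.0 over training. The sketch below assumes a baselines-style helper (the real `constfn` is imported elsewhere in this codebase) and adds a hypothetical `linear_schedule` to show the expected interface.

# Sketch only: the real constfn lives in the surrounding codebase;
# linear_schedule is a hypothetical example of a callable schedule.
def constfn(val):
    # wrap a constant so it can be called like a schedule: f(frac) -> val
    def f(_):
        return val
    return f

def linear_schedule(initial_value):
    # anneal linearly towards 0 as frac goes from 1.0 (start) to ~0.0 (end)
    def f(frac):
        return initial_value * frac
    return f

# learn() evaluates lr(frac) and cliprange(frac) every update, so both forms work:
#   learn(..., lr=2.5e-4, cliprange=0.2)                   # constants
#   learn(..., lr=linear_schedule(2.5e-4), cliprange=0.2)  # annealed learning rate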
Example no. 2
def learn(*,
          policy,
          env,
          eval_env,
          nsteps,
          total_timesteps,
          ent_coef,
          lr,
          vf_coef=0.5,
          max_grad_norm=0.5,
          gamma=0.99,
          lam=0.95,
          log_interval=10,
          nminibatches=4,
          noptepochs=4,
          cliprange=0.2,
          save_interval=0,
          load_path=None):
    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
    mpi_size = comm.Get_size()

    #tf.compat.v1.disable_v2_behavior()
    sess = tf.compat.v1.get_default_session()

    if isinstance(lr, float): lr = constfn(lr)
    else: assert callable(lr)
    if isinstance(cliprange, float): cliprange = constfn(cliprange)
    else: assert callable(cliprange)
    total_timesteps = int(total_timesteps)

    nenvs = env.num_envs
    ob_space = env.observation_space
    ac_space = env.action_space
    nbatch = nenvs * nsteps

    nbatch_train = nbatch // nminibatches

    model = Model(policy=policy,
                  ob_space=ob_space,
                  ac_space=ac_space,
                  nbatch_act=nenvs,
                  nbatch_train=nbatch_train,
                  nsteps=nsteps,
                  ent_coef=ent_coef,
                  vf_coef=vf_coef,
                  max_grad_norm=max_grad_norm)

    utils.load_all_params(sess)

    runner = Runner(env=env,
                    model=model,
                    eval_env=eval_env,
                    nsteps=nsteps,
                    gamma=gamma,
                    lam=lam)

    epinfobuf10 = deque(maxlen=10)
    epinfobuf100 = deque(maxlen=100)
    eval_epinfobuf100 = deque(maxlen=100)
    tfirststart = time.time()
    active_ep_buf = epinfobuf100
    eval_active_ep_buf = eval_epinfobuf100

    nupdates = total_timesteps // nbatch
    mean_rewards = []
    datapoints = []

    run_t_total = 0
    train_t_total = 0

    can_save = True
    checkpoints = [32, 64]
    saved_key_checkpoints = [False] * len(checkpoints)

    if Config.SYNC_FROM_ROOT and rank != 0:
        can_save = False

    def save_model(base_name=None):
        base_dict = {'datapoints': datapoints}
        utils.save_params_in_scopes(sess, ['model'],
                                    Config.get_save_file(base_name=base_name),
                                    base_dict)

    # For logging purposes, allow restoring of update
    start_update = 0
    if Config.RESTORE_STEP is not None:
        start_update = Config.RESTORE_STEP // nbatch

    tb_writer = TB_Writer(sess)
    import os
    os.environ["WANDB_API_KEY"] = "02e3820b69de1b1fcc645edcfc3dd5c5079839a1"
    group_name = "%s__%s__%f" % (Config.ENVIRONMENT, Config.AGENT,
                                 Config.REP_LOSS_WEIGHT)
    wandb.init(project='procgen_generalization',
               entity='ssl_rl',
               config=Config.args_dict,
               group=group_name,
               mode="disabled" if Config.DISABLE_WANDB else "online")
    for update in range(start_update + 1, nupdates + 1):
        assert nbatch % nminibatches == 0
        nbatch_train = nbatch // nminibatches
        tstart = time.time()
        frac = 1.0 - (update - 1.0) / nupdates
        lrnow = lr(frac)
        cliprangenow = cliprange(frac)

        mpi_print('collecting rollouts...')
        run_tstart = time.time()

        packed = runner.run(update_frac=update / nupdates)
        obs, returns, returns_i, masks, actions, values, values_i, neglogpacs, infos, states_nce, anchors_nce, labels_nce, epinfos, eval_epinfos = packed

        # reshape our augmented state vectors to match first dim of observation array
        # (mb_size*num_envs, 64*64*RGB)
        # (mb_size*num_envs, num_actions)
        avg_value = np.mean(values)
        epinfobuf10.extend(epinfos)
        epinfobuf100.extend(epinfos)
        eval_epinfobuf100.extend(eval_epinfos)

        run_elapsed = time.time() - run_tstart
        run_t_total += run_elapsed
        mpi_print('rollouts complete')

        mblossvals = []

        mpi_print('updating parameters...')
        train_tstart = time.time()

        mean_cust_loss = 0
        inds = np.arange(nbatch)
        for _ in range(noptepochs):
            np.random.shuffle(inds)
            for start in range(0, nbatch, nbatch_train):
                sess.run([model.train_model.train_dropout_assign_ops])
                end = start + nbatch_train
                mbinds = inds[start:end]
                # phi_bars are not used, so the CUSTOM_REP_LOSS and the default
                # path slice exactly the same arrays
                slices = (arr[mbinds] for arr in (obs, returns, returns_i,
                                                  masks, actions, values,
                                                  values_i, neglogpacs))

                mblossvals.append(
                    model.train(lrnow, cliprangenow, states_nce, anchors_nce,
                                labels_nce, *slices))
        # update the dropout mask
        sess.run([model.train_model.train_dropout_assign_ops])
        sess.run([model.train_model.run_dropout_assign_ops])

        train_elapsed = time.time() - train_tstart
        train_t_total += train_elapsed
        mpi_print('update complete')

        lossvals = np.mean(mblossvals, axis=0)
        tnow = time.time()
        fps = int(nbatch / (tnow - tstart))

        step = update * nbatch
        if update % log_interval == 0 or update == 1:
            eval_rew_mean = utils.process_ep_buf(eval_active_ep_buf,
                                                 tb_writer=tb_writer,
                                                 suffix='_eval',
                                                 step=step)
            rew_mean_10 = utils.process_ep_buf(active_ep_buf,
                                               tb_writer=tb_writer,
                                               suffix='',
                                               step=step)
            ep_len_mean = np.nanmean([epinfo['l'] for epinfo in active_ep_buf])

            mpi_print('\n----', update)

            mean_rewards.append(rew_mean_10)
            datapoints.append([step, rew_mean_10])
            tb_writer.log_scalar(ep_len_mean, 'ep_len_mean', step=step)
            tb_writer.log_scalar(fps, 'fps', step=step)
            tb_writer.log_scalar(avg_value, 'avg_value', step=step)
            tb_writer.log_scalar(mean_cust_loss, 'custom_loss', step=step)

            mpi_print('time_elapsed', tnow - tfirststart, run_t_total,
                      train_t_total)
            mpi_print('timesteps', update * nsteps, total_timesteps)

            mpi_print('eplenmean', ep_len_mean)
            mpi_print('eprew', rew_mean_10)
            mpi_print('eprew_eval', eval_rew_mean)
            mpi_print('fps', fps)
            mpi_print('total_timesteps', update * nbatch)
            mpi_print([epinfo['r'] for epinfo in epinfobuf10])

            rep_loss = 0
            if len(mblossvals):
                for (lossval, lossname) in zip(lossvals, model.loss_names):
                    if lossname == 'rep_loss':
                        rep_loss = lossval
                    mpi_print(lossname, lossval)
                    tb_writer.log_scalar(lossval, lossname, step=step)
            mpi_print('----\n')

            wandb.log({
                "%s/ep_len_mean" % (Config.ENVIRONMENT): ep_len_mean,
                "%s/avg_value" % (Config.ENVIRONMENT): avg_value,
                "%s/custom_loss" % (Config.ENVIRONMENT): mean_cust_loss,
                "%s/eplenmean" % (Config.ENVIRONMENT): ep_len_mean,
                "%s/eprew" % (Config.ENVIRONMENT): rew_mean_10,
                "%s/eprew_eval" % (Config.ENVIRONMENT): eval_rew_mean,
                "%s/rep_loss" % (Config.ENVIRONMENT): rep_loss,
                "%s/custom_step" % (Config.ENVIRONMENT): step
            })
        if can_save:
            if save_interval and (update % save_interval == 0):
                save_model()

            for j, checkpoint in enumerate(checkpoints):
                if (not saved_key_checkpoints[j]) and (step >=
                                                       (checkpoint * 1e6)):
                    saved_key_checkpoints[j] = True
                    save_model(str(checkpoint) + 'M')

    save_model()

    env.close()
    return mean_rewards
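The optimization loop above shuffles a flat index array once per epoch and slices every rollout array with the same minibatch indices, handing a one-shot generator to `model.train`. A self-contained NumPy illustration of that slicing pattern (toy shapes, not the real rollout arrays):

import numpy as np

nbatch, nminibatches = 16, 4
nbatch_train = nbatch // nminibatches

obs = np.arange(nbatch * 3).reshape(nbatch, 3)      # stand-in rollout arrays
returns = np.arange(nbatch, dtype=np.float32)

inds = np.arange(nbatch)
np.random.shuffle(inds)                             # reshuffled once per epoch
for start in range(0, nbatch, nbatch_train):
    mbinds = inds[start:start + nbatch_train]
    # generator of per-array minibatch views, consumed once by the train call
    slices = (arr[mbinds] for arr in (obs, returns))
    mb_obs, mb_returns = slices
    print(mb_obs.shape, mb_returns.shape)           # (4, 3) (4,)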
Example no. 3
def learn(*,
          policy,
          env,
          eval_env,
          nsteps,
          total_timesteps,
          ent_coef,
          lr,
          vf_coef=0.5,
          max_grad_norm=0.5,
          gamma=0.99,
          lam=0.95,
          log_interval=10,
          nminibatches=4,
          noptepochs=4,
          cliprange=0.2,
          save_interval=0,
          load_path=None):
    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
    mpi_size = comm.Get_size()

    #tf.compat.v1.disable_v2_behavior()
    sess = tf.compat.v1.get_default_session()

    if isinstance(lr, float): lr = constfn(lr)
    else: assert callable(lr)
    if isinstance(cliprange, float): cliprange = constfn(cliprange)
    else: assert callable(cliprange)
    total_timesteps = int(total_timesteps)

    nenvs = env.num_envs
    ob_space = env.observation_space
    ac_space = env.action_space
    nbatch = nenvs * nsteps

    nbatch_train = nbatch // nminibatches
    model = Model(policy=policy,
                  ob_space=ob_space,
                  ac_space=ac_space,
                  nbatch_act=nenvs,
                  nbatch_train=nbatch_train,
                  nsteps=nsteps,
                  ent_coef=ent_coef,
                  vf_coef=vf_coef,
                  max_grad_norm=max_grad_norm)
    # with tf.compat.v1.variable_scope('model_1'):
    # model = Model(policy=policy, ob_space=ob_space, ac_space=ac_space, nbatch_act=nenvs, nbatch_train=nbatch_train,
    # 				nsteps=nsteps, ent_coef=ent_coef, vf_coef=vf_coef,
    # 				max_grad_norm=max_grad_norm)

    # model_1_vars  = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='model_1'):
    # model_saver = tf.compat.v1.train.Saver(var_list=model_1_vars)
    # saver to save and load models
    # import ipdb;ipdb.set_trace()
    model_saver = tf.compat.v1.train.Saver()
    if Config.RESTORE_IDD is not None:
        mpi_print('Restoring model...')
        model_saver.restore(sess,
                            save_path='{}{}-1'.format(Config.RESTORE_PATH,
                                                      Config.RESTORE_IDD))
    # TODO(Ahmed) remove coinrun saving code if default TF methods work
    #utils.load_all_params(sess)

    runner = Runner(env=env,
                    eval_env=eval_env,
                    model=model,
                    nsteps=nsteps,
                    gamma=gamma,
                    lam=lam)

    epinfobuf10 = deque(maxlen=10)
    epinfobuf100 = deque(maxlen=100)
    eval_epinfobuf100 = deque(maxlen=100)
    tfirststart = time.time()
    active_ep_buf = epinfobuf100
    eval_active_ep_buf = eval_epinfobuf100

    nupdates = total_timesteps // nbatch
    mean_rewards = []
    datapoints = []

    run_t_total = 0
    train_t_total = 0

    can_save = False
    checkpoints = [32, 64]
    saved_key_checkpoints = [False] * len(checkpoints)

    if Config.SYNC_FROM_ROOT and rank != 0:
        can_save = False

    # TODO(Ahmed) remove coinrun saving code if default TF methods work
    # def save_model(base_name=None):
    # 	base_dict = {'datapoints': datapoints}
    # 	utils.save_params_in_scopes(sess, ['model'], Config.get_save_file(base_name=base_name), base_dict)

    # For logging purposes, allow restoring of update
    start_update = 0
    if Config.RESTORE_STEP is not None:
        start_update = Config.RESTORE_STEP // nbatch

    z_iter = 0
    curr_z = np.random.randint(0, high=Config.POLICY_NHEADS)
    tb_writer = TB_Writer(sess)
    import os
    os.environ["WANDB_API_KEY"] = "02e3820b69de1b1fcc645edcfc3dd5c5079839a1"
    group_name = "%s__%s__%s" % (Config.ENVIRONMENT, Config.AGENT,
                                 Config.RUN_ID)
    run_name = "%s__%s__%s__%d" % (Config.ENVIRONMENT, Config.AGENT,
                                   Config.RUN_ID, Config.START_LEVEL)
    wandb.init(project='procgen_generalization',
               entity='ssl_rl',
               config=Config.args_dict,
               group=group_name,
               name=run_name,
               mode="disabled" if Config.DISABLE_WANDB else "online")

    print('TF built with CUDA:')
    print(tf.test.is_built_with_cuda())
    for update in range(start_update + 1, nupdates + 1):
        assert nbatch % nminibatches == 0
        nbatch_train = nbatch // nminibatches
        tstart = time.time()
        frac = 1.0 - (update - 1.0) / nupdates
        lrnow = lr(frac)
        cliprangenow = cliprange(frac)

        # if Config.CUSTOM_REP_LOSS:
        #     params = tf.compat.v1.trainable_variables()
        #     source_params = [p for p in params if p.name in model.train_model.RL_enc_param_names]
        #     for i in range(1,Config.POLICY_NHEADS):
        #         target_i_params = [p for p in params if p.name in model.train_model.target_enc_param_names[i]]
        #         soft_update(source_params,target_i_params,tau=0.95)

        mpi_print('collecting rollouts...')
        run_tstart = time.time()
        # if z_iter < 4: # 8 epochs / skill
        #     z_iter += 1
        # else:
        #     # sample new skill for current episodes
        #     curr_z = np.random.randint(0, high=Config.POLICY_NHEADS)
        #     model.head_idx_current_batch = curr_z
        #     z_iter = 0

        packed = runner.run(update_frac=update / nupdates)
        if Config.CUSTOM_REP_LOSS:
            obs, returns, masks, actions, values, neglogpacs, infos, values_i, returns_i, states_nce, anchors_nce, labels_nce, actions_nce, neglogps_nce, rewards_nce, infos_nce, epinfos, eval_epinfos = packed
        else:
            obs, returns, masks, actions, values, neglogpacs, infos, epinfos, eval_epinfos = packed
            values_i = returns_i = states_nce = anchors_nce = labels_nce = actions_nce = neglogps_nce = rewards_nce = infos_nce = None

        # reshape our augmented state vectors to match first dim of observation array
        # (mb_size*num_envs, 64*64*RGB)
        # (mb_size*num_envs, num_actions)
        avg_value = np.mean(values)
        epinfobuf10.extend(epinfos)
        epinfobuf100.extend(epinfos)
        eval_epinfobuf100.extend(eval_epinfos)

        run_elapsed = time.time() - run_tstart
        run_t_total += run_elapsed
        mpi_print('rollouts complete')

        mblossvals = []

        mpi_print('updating parameters...')
        train_tstart = time.time()

        mean_cust_loss = 0
        inds = np.arange(nbatch)
        inds_nce = np.arange(nbatch // runner.nce_update_freq)
        for _ in range(noptepochs):
            np.random.shuffle(inds)
            np.random.shuffle(inds_nce)
            for start in range(0, nbatch, nbatch_train):
                sess.run([model.train_model.train_dropout_assign_ops])
                end = start + nbatch_train
                mbinds = inds[start:end]
                mbinds_nce = inds_nce[start //
                                      runner.nce_update_freq:(start +
                                                              nbatch_train) //
                                      runner.nce_update_freq]

                slices = (arr[mbinds] for arr in (obs, returns, masks, actions,
                                                  infos, values, neglogpacs))
                if Config.CUSTOM_REP_LOSS:
                    slices_nce = (arr[mbinds_nce]
                                  for arr in (values_i, returns_i, states_nce,
                                              anchors_nce, labels_nce,
                                              actions_nce, neglogps_nce,
                                              rewards_nce, infos_nce))
                else:
                    slices_nce = (arr
                                  for arr in (values_i, returns_i, states_nce,
                                              anchors_nce, labels_nce,
                                              actions_nce, neglogps_nce,
                                              rewards_nce, infos_nce))

                mblossvals.append(
                    model.train(lrnow, cliprangenow, *slices, *slices_nce))
        # update the dropout mask
        sess.run([model.train_model.train_dropout_assign_ops])
        sess.run([model.train_model.run_dropout_assign_ops])

        train_elapsed = time.time() - train_tstart
        train_t_total += train_elapsed
        mpi_print('update complete')

        lossvals = np.mean(mblossvals, axis=0)
        tnow = time.time()
        fps = int(nbatch / (tnow - tstart))

        if update % log_interval == 0 or update == 1:
            step = update * nbatch
            eval_rew_mean = utils.process_ep_buf(eval_active_ep_buf,
                                                 tb_writer=tb_writer,
                                                 suffix='_eval',
                                                 step=step)
            rew_mean_10 = utils.process_ep_buf(active_ep_buf,
                                               tb_writer=tb_writer,
                                               suffix='',
                                               step=step)

            ep_len_mean = np.nanmean([epinfo['l'] for epinfo in active_ep_buf])

            mpi_print('\n----', update)

            mean_rewards.append(rew_mean_10)
            datapoints.append([step, rew_mean_10])
            tb_writer.log_scalar(ep_len_mean, 'ep_len_mean', step=step)
            tb_writer.log_scalar(fps, 'fps', step=step)
            tb_writer.log_scalar(avg_value, 'avg_value', step=step)
            tb_writer.log_scalar(mean_cust_loss, 'custom_loss', step=step)

            mpi_print('time_elapsed', tnow - tfirststart, run_t_total,
                      train_t_total)
            mpi_print('timesteps', update * nsteps, total_timesteps)

            # eval_rew_mean = episode_rollouts(eval_env,model,step,tb_writer)

            mpi_print('eplenmean', ep_len_mean)
            mpi_print('eprew', rew_mean_10)
            mpi_print('eprew_eval', eval_rew_mean)
            mpi_print('fps', fps)
            mpi_print('total_timesteps', update * nbatch)
            mpi_print([epinfo['r'] for epinfo in epinfobuf10])

            rep_loss = 0
            if len(mblossvals):
                for (lossval, lossname) in zip(lossvals, model.loss_names):
                    mpi_print(lossname, lossval)
                    tb_writer.log_scalar(lossval, lossname, step=step)
            mpi_print('----\n')

            wandb.log({
                "%s/eprew" % (Config.ENVIRONMENT): rew_mean_10,
                "%s/eprew_eval" % (Config.ENVIRONMENT): eval_rew_mean,
                "%s/custom_step" % (Config.ENVIRONMENT): step
            })
        # if update == 1:
        # 	quit()
        # 	break
        # can_save = True
        # save_interval = 10
        # if can_save:
        # 	if save_interval and (update % save_interval == 0):
        # 		save_model()

        # for j, checkpoint in enumerate(checkpoints):
        # 	if (not saved_key_checkpoints[j]) and (step >= (checkpoint * 1e6)):
        # 		saved_key_checkpoints[j] = True
        # 		save_model(str(checkpoint) + 'M')

    # save_model()

    # save model at the end of training loop
    save_model_path = '%s-%s-%s' % (Config.RESTORE_PATH, Config.RUN_ID,
                                    Config.START_LEVEL)
    true_path = model_saver.save(sess,
                                 save_path=save_model_path + '/ppo',
                                 global_step=1)
    wandb.save(save_model_path + '/*')
    mpi_print('true path for checkpoint', true_path)
    env.close()
    return mean_rewards
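Example no. 3 restores from '{}{}-1'.format(Config.RESTORE_PATH, Config.RESTORE_IDD) and saves with global_step=1, which is what produces the "-1" suffix and the .index/.data-*/.meta files that Example no. 4 later downloads. A minimal, self-contained sketch of that TF1-style Saver round trip, using a placeholder path and a dummy variable rather than the real model:

import os
import tensorflow as tf

tf.compat.v1.disable_eager_execution()

v = tf.compat.v1.get_variable("model/dummy", shape=[1])    # placeholder variable
saver = tf.compat.v1.train.Saver()
os.makedirs("/tmp/ckpt", exist_ok=True)

with tf.compat.v1.Session() as sess:
    sess.run(tf.compat.v1.global_variables_initializer())
    # save(..., global_step=1) writes ppo-1.index / ppo-1.data-* / ppo-1.meta,
    # which is why the restore path above ends in "-1"
    true_path = saver.save(sess, save_path="/tmp/ckpt/ppo", global_step=1)
    saver.restore(sess, save_path=true_path)               # e.g. "/tmp/ckpt/ppo-1"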
Example no. 4
def learn(*, policy, env, eval_env, nsteps, total_timesteps, ent_coef, lr,
			 vf_coef=0.5,  max_grad_norm=0.5, gamma=0.99, lam=0.95,
			log_interval=10, nminibatches=4, noptepochs=4, cliprange=0.2,
			save_interval=0, load_path=None):
	comm = MPI.COMM_WORLD
	rank = comm.Get_rank()
	mpi_size = comm.Get_size()

	#tf.compat.v1.disable_v2_behavior()
	sess = tf.compat.v1.get_default_session()

	if isinstance(lr, float): lr = constfn(lr)
	else: assert callable(lr)
	if isinstance(cliprange, float): cliprange = constfn(cliprange)
	else: assert callable(cliprange)
	total_timesteps = int(total_timesteps)
	
	nenvs = env.num_envs
	ob_space = env.observation_space
	ac_space = env.action_space
	nbatch = nenvs * nsteps
	
	nbatch_train = nbatch // nminibatches
	model = Model(policy=policy, ob_space=ob_space, ac_space=ac_space, nbatch_act=nenvs, nbatch_train=nbatch_train,
					nsteps=nsteps, ent_coef=ent_coef, vf_coef=vf_coef,
					max_grad_norm=max_grad_norm)
	utils.load_all_params(sess)

	runner = Runner(env=env, eval_env=eval_env, model=model, nsteps=nsteps, gamma=gamma, lam=lam)

	epinfobuf10 = deque(maxlen=10)
	epinfobuf100 = deque(maxlen=100)
	eval_epinfobuf100 = deque(maxlen=100)
	tfirststart = time.time()
	active_ep_buf = epinfobuf100
	eval_active_ep_buf = eval_epinfobuf100

	nupdates = total_timesteps//nbatch
	mean_rewards = []
	datapoints = []

	run_t_total = 0
	train_t_total = 0

	can_save = False
	checkpoints = [32, 64]
	saved_key_checkpoints = [False] * len(checkpoints)

	if Config.SYNC_FROM_ROOT and rank != 0:
		can_save = False

	def save_model(base_name=None):
		base_dict = {'datapoints': datapoints}
		utils.save_params_in_scopes(sess, ['model'], Config.get_save_file(base_name=base_name), base_dict)

	# For logging purposes, allow restoring of update
	start_update = 0
	if Config.RESTORE_STEP is not None:
		start_update = Config.RESTORE_STEP // nbatch

	z_iter = 0
	curr_z = np.random.randint(0, high=Config.POLICY_NHEADS)
	tb_writer = TB_Writer(sess)
	import os
	os.environ["WANDB_API_KEY"] = "02e3820b69de1b1fcc645edcfc3dd5c5079839a1"
	os.environ["WANDB_SILENT"] = "true"
	run_id = np.random.randint(100000000)
	os.environ["WANDB_RUN_ID"] = str(run_id)
	group_name = "%s__%s__%f__%f" %(Config.ENVIRONMENT,Config.RUN_ID,Config.REP_LOSS_WEIGHT, Config.TEMP)
	name = "%s__%s__%f__%f__%d" %(Config.ENVIRONMENT,Config.RUN_ID,Config.REP_LOSS_WEIGHT, Config.TEMP, run_id)
	wandb.init(project='ising_generalization' if Config.ENVIRONMENT == 'ising' else 'procgen_generalization' ,
			  entity='ssl_rl', config=Config.args_dict,
			  group=group_name, name=name,
			  mode="disabled" if Config.DISABLE_WANDB else "online")

	api = wandb.Api()
	list_runs = api.runs("ssl_rl/procgen_generalization")
	single_level_runs=[run for run in list_runs if 'ppo_per_level' in run.name]
	non_crashed = [run for run in single_level_runs if run.state in ['running','finished']]
	game_runs = [run for run in non_crashed if Config.ENVIRONMENT in run.name]
	wandb_save_dir = '%s/%s'%(Config.RESTORE_PATH,Config.ENVIRONMENT)
	print('Save dir: %s'%wandb_save_dir)
	if not os.path.isdir(wandb_save_dir):
		import requests
		for run in game_runs:
			level_id = run.name.split('__')[-1]
			run_save_dir = wandb_save_dir + '/' + level_id
			if not os.path.isdir(run_save_dir):
				os.makedirs(run_save_dir)

			def save_wandb_file(name):
				url = "https://api.wandb.ai/files/ssl_rl/procgen_generalization/%s/%s"%(run.id,name)
				r = requests.get(url)
				with open(run_save_dir+'/%s'%name , 'wb') as fh:
					fh.write(r.content)

			save_wandb_file('checkpoint')
			save_wandb_file('ppo-1.data-00000-of-00001')
			save_wandb_file('ppo-1.index')
			save_wandb_file('ppo-1.meta')

			print('Downloaded level id %s to %s (run id: %s)' % (level_id,run_save_dir,run.id) )
			print(os.listdir(run_save_dir))
			# wandb.restore(wandb_save_dir+"/checkpoint",run_path='/'.join(run.path))

	# load in just the graph and model parameters outside for-loop
	from coinrun import policies as policies_ppo
	ppo = policies_ppo.get_policy()
	ppo_graph_1, ppo_graph_2 = tf.Graph(), tf.Graph()

	PSE_policy = Config.PSE_POLICY

	if PSE_policy == 'ppo_2':
		levels = np.unique(os.listdir(wandb_save_dir)).astype(int)
		if Config.ENVIRONMENT == 'bigfish':
			levels = np.setdiff1d(levels,np.array([4]))

		pse_replay = []
		for mdp_id in levels:
			print('Collecting MDP %d'%mdp_id)
			mb_obs_i, mb_actions_i, mb_rewards_i = generate_level_replay(ppo,mdp_id,wandb_save_dir,nbatch_train, nsteps, max_grad_norm, ob_space, ac_space, nsteps_rollout=782)
			pse_replay.append([mb_obs_i, mb_actions_i, mb_rewards_i])

		
	for update in range(start_update+1, nupdates+1):
		assert nbatch % nminibatches == 0
		nbatch_train = nbatch // nminibatches
		tstart = time.time()
		frac = 1.0 - (update - 1.0) / nupdates
		lrnow = lr(frac)
		cliprangenow = cliprange(frac)

		# mpi_print('collecting rollouts...')
		run_tstart = time.time()

		packed = runner.run(update_frac=update/nupdates)
	
		obs, returns, masks, actions, values, neglogpacs, infos, rewards, epinfos, eval_epinfos = packed
		values_i = returns_i = states_nce = anchors_nce = labels_nce = actions_nce = neglogps_nce = rewards_nce = infos_nce = None

		"""
		PSE data re-collection

		1. Make 2 envs for respective policies for 2 random levels
		"""
		
		levels = np.unique(os.listdir(wandb_save_dir)).astype(int)
		if Config.ENVIRONMENT == 'bigfish':
			levels = np.setdiff1d(levels,np.array([4]))
		mdp_1,mdp_2 = np.random.choice(levels,size=2,replace=False)
		# import ipdb;ipdb.set_trace()
		observation_space = Dict(rgb=Box(shape=(64,64,3),low=0,high=255))
		action_space = DiscreteG(15)

		gym3_env_eval_1 = ProcgenGym3Env(num=Config.NUM_ENVS, env_name=Config.ENVIRONMENT, num_levels=1, start_level=int(mdp_1), paint_vel_info=Config.PAINT_VEL_INFO, distribution_mode=Config.FIRST_PHASE)
		venv_eval_1 = FakeEnv(gym3_env_eval_1, observation_space, action_space)
		venv_eval_1 = VecExtractDictObs(venv_eval_1, "rgb")
		venv_eval_1 = VecMonitor(
			venv=venv_eval_1, filename=None, keep_buf=100,
		)
		venv_eval_1 = VecNormalize(venv=venv_eval_1, ob=False)
		venv_eval_1 = wrappers.add_final_wrappers(venv_eval_1)

		gym3_env_eval_2 = ProcgenGym3Env(num=Config.NUM_ENVS, env_name=Config.ENVIRONMENT, num_levels=1, start_level=int(mdp_2), paint_vel_info=Config.PAINT_VEL_INFO, distribution_mode=Config.FIRST_PHASE)
		venv_eval_2 = FakeEnv(gym3_env_eval_2, observation_space, action_space)
		venv_eval_2 = VecExtractDictObs(venv_eval_2, "rgb")
		venv_eval_2 = VecMonitor(
			venv=venv_eval_2, filename=None, keep_buf=100,
		)
		venv_eval_2 = VecNormalize(venv=venv_eval_2, ob=False)
		venv_eval_2 = wrappers.add_final_wrappers(venv_eval_2)

		def random_policy(states):
			actions = np.random.randint(0,15,Config.NUM_ENVS)
			return actions

		# print('Loading weights from %s'%(wandb_save_dir+'/%d/ppo-1'%mdp_1))
		# with ppo_graph.as_default():
		#     ppo_model = ppo(sess, ob_space, ac_space, nbatch_train, nsteps, max_grad_norm, override_agent='ppo')
		#import ipdb;ipdb.set_trace()
		# NOTE: this is recreating a graph within the updates, I'm moving them outside the training loop

		if PSE_policy == 'ppo':
			print('Using pretrained PPO policy')
			model1_path = wandb_save_dir+'/%d/ppo-1'%mdp_1
			model2_path = wandb_save_dir+'/%d/ppo-1'%mdp_2
			graph_one_vars = ppo_graph_1.get_all_collection_keys()

			with tf.compat.v1.Session(graph=ppo_graph_1,config=tf.ConfigProto(inter_op_parallelism_threads=1,intra_op_parallelism_threads=1)) as sess_1:
				with tf.compat.v1.variable_scope("model_1"):
					ppo_model_1 = ppo(sess_1, ob_space, ac_space, nbatch_train, nsteps, max_grad_norm, override_agent='ppo')
					initialize = tf.compat.v1.global_variables_initializer()
					sess_1.run(initialize)
				model_saver = tf.train.import_meta_graph(model1_path+'.meta')
				model_saver.restore(sess_1, save_path=model1_path)
				mb_obs_1, mb_actions_1, mb_rewards_1 = collect_data(ppo_model_1,venv_eval_1,nsteps=32, param_vals='pretrained')

			with tf.compat.v1.Session(graph=ppo_graph_2,config=tf.ConfigProto(inter_op_parallelism_threads=1,intra_op_parallelism_threads=1)) as sess_2:
				with tf.compat.v1.variable_scope("model_2"):
					ppo_model_2 = ppo(sess_2, ob_space, ac_space, nbatch_train, nsteps, max_grad_norm, override_agent='ppo')
					initialize = tf.compat.v1.global_variables_initializer()
					sess_2.run(initialize)
				model_saver = tf.train.import_meta_graph(model2_path+'.meta')
				model_saver.restore(sess_2, save_path=model2_path)

				mb_obs_2, mb_actions_2, mb_rewards_2 = collect_data(ppo_model_2,venv_eval_2,nsteps=32, param_vals='pretrained')
		elif PSE_policy == 'random':
			print('Using random uniform policy')
			mb_obs_1, mb_actions_1, mb_rewards_1 = collect_data(random_policy,venv_eval_1,nsteps=32, param_vals='random')
			mb_obs_2, mb_actions_2, mb_rewards_2 = collect_data(random_policy,venv_eval_2,nsteps=32, param_vals='random')
		elif PSE_policy == 'ppo_2':
			mdp_1,mdp_2 = np.random.choice(np.arange(len(pse_replay)),size=2,replace=False)
			mb_obs_1, mb_actions_1, mb_rewards_1 = pse_replay[mdp_1]
			mb_obs_2, mb_actions_2, mb_rewards_2 = pse_replay[mdp_2]
		# reshape our augmented state vectors to match first dim of observation array
		# (mb_size*num_envs, 64*64*RGB)
		# (mb_size*num_envs, num_actions)
		avg_value = np.mean(values)
		epinfobuf10.extend(epinfos)
		epinfobuf100.extend(epinfos)
		eval_epinfobuf100.extend(eval_epinfos)

		run_elapsed = time.time() - run_tstart
		run_t_total += run_elapsed
		# mpi_print('rollouts complete')

		mblossvals = []

		# mpi_print('updating parameters...')
		train_tstart = time.time()

		mean_cust_loss = 0
		inds = np.arange(nbatch)
		inds_pse = np.arange(1024)
		inds_nce = np.arange(nbatch//runner.nce_update_freq)
		for _ in range(noptepochs):
			np.random.shuffle(inds)
			np.random.shuffle(inds_nce)
			for start in range(0, nbatch, nbatch_train):
				sess.run([model.train_model.train_dropout_assign_ops])
				end = start + nbatch_train
				mbinds = inds[start:end]

				
				slices = (arr[mbinds] for arr in (obs, returns, masks, actions, infos, values, neglogpacs, rewards))

				slices_pse_1 = (arr[inds_pse] for arr in (mb_obs_1, mb_actions_1, mb_rewards_1))
				slices_pse_2 = (arr[inds_pse] for arr in (mb_obs_2, mb_actions_2, mb_rewards_2))
				
				mblossvals.append(model.train(lrnow, cliprangenow, *slices, *slices_pse_1, *slices_pse_2, train_target='policy'))

				slices = (arr[mbinds] for arr in (obs, returns, masks, actions, infos, values, neglogpacs, rewards))

			np.random.shuffle(inds_pse)
			slices_pse_1 = (arr[inds_pse] for arr in (mb_obs_1, mb_actions_1, mb_rewards_1))
			slices_pse_2 = (arr[inds_pse] for arr in (mb_obs_2, mb_actions_2, mb_rewards_2))
            
			model.train(lrnow, cliprangenow, *slices, *slices_pse_1, *slices_pse_2, train_target='pse')
		# update the dropout mask
		sess.run([model.train_model.train_dropout_assign_ops])
		sess.run([model.train_model.run_dropout_assign_ops])

		train_elapsed = time.time() - train_tstart
		train_t_total += train_elapsed
		# mpi_print('update complete')

		lossvals = np.mean(mblossvals, axis=0)
		tnow = time.time()
		fps = int(nbatch / (tnow - tstart))

		if update % log_interval == 0 or update == 1:
			step = update*nbatch
			eval_rew_mean = utils.process_ep_buf(eval_active_ep_buf, tb_writer=tb_writer, suffix='_eval', step=step)
			rew_mean_10 = utils.process_ep_buf(active_ep_buf, tb_writer=tb_writer, suffix='', step=step)
			
			ep_len_mean = np.nanmean([epinfo['l'] for epinfo in active_ep_buf])
			
			mpi_print('\n----', update)

			mean_rewards.append(rew_mean_10)
			datapoints.append([step, rew_mean_10])
			tb_writer.log_scalar(ep_len_mean, 'ep_len_mean', step=step)
			tb_writer.log_scalar(fps, 'fps', step=step)
			tb_writer.log_scalar(avg_value, 'avg_value', step=step)
			tb_writer.log_scalar(mean_cust_loss, 'custom_loss', step=step)


			mpi_print('time_elapsed', tnow - tfirststart, run_t_total, train_t_total)
			mpi_print('timesteps', update*nsteps, total_timesteps)

			# eval_rew_mean = episode_rollouts(eval_env,model,step,tb_writer)

			mpi_print('eplenmean', ep_len_mean)
			mpi_print('eprew', rew_mean_10)
			mpi_print('eprew_eval', eval_rew_mean)
			mpi_print('fps', fps)
			mpi_print('total_timesteps', update*nbatch)
			mpi_print([epinfo['r'] for epinfo in epinfobuf10])

			rep_loss = 0
			if len(mblossvals):
				for (lossval, lossname) in zip(lossvals, model.loss_names):
					mpi_print(lossname, lossval)
					tb_writer.log_scalar(lossval, lossname, step=step)
			mpi_print('----\n')

			wandb.log({"%s/eprew"%(Config.ENVIRONMENT):rew_mean_10,
						"%s/eprew_eval"%(Config.ENVIRONMENT):eval_rew_mean,
						"%s/custom_step"%(Config.ENVIRONMENT):step})
		if can_save:
			if save_interval and (update % save_interval == 0):
				save_model()

			for j, checkpoint in enumerate(checkpoints):
				if (not saved_key_checkpoints[j]) and (step >= (checkpoint * 1e6)):
					saved_key_checkpoints[j] = True
					save_model(str(checkpoint) + 'M')

	save_model()

	env.close()
	# import subprocess
	# wandb_files = os.listdir('wandb')
	# file_to_save = ''
	# for fn in wandb_files:
	# 	if str(run_id) in fn:
	# 		file_to_save = fn
	# 		break
	# print(file_to_save)
	# my_env = os.environ.copy()
	# my_env["WANDB_API_KEY"] = "02e3820b69de1b1fcc645edcfc3dd5c5079839a1"
	# subprocess.call(['wandb','sync','wandb/'+ file_to_save],env=my_env)
	return mean_rewards
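The two venv_eval_* blocks in Example no. 4 differ only in start_level. A sketch of factoring that duplicated wrapper stack into one helper, assuming the same imports and Config fields already used above (ProcgenGym3Env, FakeEnv, VecExtractDictObs, VecMonitor, VecNormalize, wrappers):

def make_eval_env(start_level, observation_space, action_space):
    # single-level Procgen env wrapped exactly like venv_eval_1 / venv_eval_2 above
    gym3_env = ProcgenGym3Env(num=Config.NUM_ENVS,
                              env_name=Config.ENVIRONMENT,
                              num_levels=1,
                              start_level=int(start_level),
                              paint_vel_info=Config.PAINT_VEL_INFO,
                              distribution_mode=Config.FIRST_PHASE)
    venv = FakeEnv(gym3_env, observation_space, action_space)
    venv = VecExtractDictObs(venv, "rgb")
    venv = VecMonitor(venv=venv, filename=None, keep_buf=100)
    venv = VecNormalize(venv=venv, ob=False)
    return wrappers.add_final_wrappers(venv)

# venv_eval_1 = make_eval_env(mdp_1, observation_space, action_space)
# venv_eval_2 = make_eval_env(mdp_2, observation_space, action_space)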
Example no. 5
def main():
    # general setup

    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'

    args = setup_utils.setup_and_load()

    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()

    seed = int(time.time()) % 10000
    set_global_seeds(seed * 100 + rank)

    utils.setup_mpi_gpus()

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True  # pylint: disable=E1101

    # prepare directory
    sub_dir = utils.file_to_path(Config.get_save_file(base_name="tmp"))
    if os.path.isdir(sub_dir):
        shutil.rmtree(path=sub_dir)
    os.mkdir(sub_dir)

    # hyperparams
    nenvs = Config.NUM_ENVS
    total_timesteps = Config.TIMESTEPS
    population_size = Config.POPULATION_SIZE
    timesteps_per_agent = Config.TIMESTEPS_AGENT
    worker_count = Config.WORKER_COUNT
    passthrough_perc = Config.PASSTHROUGH_PERC
    mutating_perc = Config.MUTATING_PERC

    # create environment
    def make_env():
        env = utils.make_general_env(nenvs, seed=rank)
        env = wrappers.add_final_wrappers(env)
        return env

    # setup session and workers, and therefore tensorflow ops
    graph = tf.get_default_graph()
    sess = tf.Session(graph=graph, config=config)

    policy = policies.get_policy()

    workers = [
        Worker(sess, i, nenvs, make_env, policy, sub_dir)
        for i in range(worker_count)
    ]

    tb_writer = TB_Writer(sess)

    def clean_exit():

        for worker in workers:
            Thread.join(worker.thread)

        utils.mpi_print("")
        utils.mpi_print("== total duration",
                        "{:.1f}".format(time.time() - t_first_start), " s ==")
        utils.mpi_print(" exit...")

        # save best performing agent
        population.sort(key=lambda k: k['fit'], reverse=True)
        workers[0].restore_model(name=population[0]["name"])
        workers[0].dump_model()

        # cleanup
        sess.close()
        shutil.rmtree(path=sub_dir)

    # load data from restore point and seed the whole population
    loaded_name = None
    if workers[0].try_load_model():
        loaded_name = str(uuid.uuid1())
        workers[0].save_model(name=loaded_name)

    # initialise population
    # either all random and no mutations pending
    # or all from restore point with all but one to be mutated
    population = [{
        "name": loaded_name or str(uuid.uuid1()),
        "fit": -1,
        "need_mut": loaded_name != None and i != 0,
        "age": -1,
        "mean_ep_len": -1
    } for i in range(population_size)]

    utils.mpi_print("== population size", population_size, ", t_agent ",
                    timesteps_per_agent, " ==")

    t_first_start = time.time()
    try:
        # main loop
        generation = 0
        timesteps_done = 0
        while timesteps_done < total_timesteps:
            t_generation_start = time.time()

            utils.mpi_print("")
            utils.mpi_print("__ Generation", generation, " __")

            # initialise and evaluate all new agents
            for agent in population:
                #if agent["fit"] < 0: # test/
                if True:  # test constant reevaluation, to dismiss "lucky runs" -> seems good

                    # pick worker from pool and let it work on the agent
                    not_in_work = True
                    while not_in_work:
                        for worker in workers:
                            if worker.can_take_work():
                                worker.work(agent, timesteps_per_agent)
                                not_in_work = False
                                break

                    timesteps_done += timesteps_per_agent * nenvs

            for worker in workers:
                Thread.join(worker.thread)

            # sort by fitness
            population.sort(key=lambda k: k["fit"], reverse=True)

            # print stuff
            fitnesses = [agent["fit"] for agent in population]
            ages = [agent["age"] for agent in population]
            ep_lens = [agent["mean_ep_len"] for agent in population]

            utils.mpi_print(*["{:5.3f}".format(f) for f in fitnesses])
            utils.mpi_print(*["{:5}".format(a) for a in ages])
            utils.mpi_print("__ average fit", "{:.1f}".format(
                np.mean(fitnesses)), ", t_done", timesteps_done, ", took",
                            "{:.1f}".format(time.time() - t_generation_start),
                            "s", ", total",
                            "{:.1f}".format(time.time() - t_first_start),
                            "s __")

            # log stuff
            tb_writer.log_scalar(np.mean(fitnesses), "mean_fit",
                                 timesteps_done)
            tb_writer.log_scalar(np.median(fitnesses), "median_fit",
                                 timesteps_done)
            tb_writer.log_scalar(np.max(fitnesses), "max_fit", timesteps_done)
            tb_writer.log_scalar(np.mean(ages), "mean_age", timesteps_done)
            ep_lens_mean = np.nanmean(ep_lens)
            if not np.isnan(ep_lens_mean):
                tb_writer.log_scalar(ep_lens_mean, "mean_ep_lens",
                                     timesteps_done)

            # cleanup to prevent disk clutter
            to_be_removed = set(
                re.sub(r'\..*$', '', f) for f in os.listdir(sub_dir)) - set(
                    [agent["name"] for agent in population])
            for filename in to_be_removed:
                os.remove(sub_dir + "/" + filename + ".index")
                os.remove(sub_dir + "/" + filename + ".data-00000-of-00001")

            # break when times up
            if not timesteps_done < total_timesteps:
                break

            # mark weak agents for replacement
            cutoff_passthrough = math.floor(population_size * passthrough_perc)
            cutoff_mutating = math.floor(population_size * mutating_perc)
            source_agents = population[:cutoff_mutating]

            new_population = population[:cutoff_passthrough]

            k = 0
            while len(new_population) < population_size:
                new_agent = {
                    # take the name from the source agent, so mutation knows the parent
                    "name": source_agents[k]["name"],
                    "fit": -1,
                    "need_mut": True,
                    "age": 0,
                    "mean_ep_len": -1
                }
                new_population.append(new_agent)
                k = (k + 1) % len(source_agents)

            population = new_population
            generation += 1

        clean_exit()
    except KeyboardInterrupt:
        clean_exit()

    return 0
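The generation update in Example no. 5 is truncation selection: sort by fitness, keep the top passthrough_perc unchanged, and refill the population with mutation-marked clones drawn round-robin from the top mutating_perc. A standalone toy run of that step (fitness values invented for the demo):

import math

population = [{"name": n, "fit": f, "need_mut": False, "age": 1}
              for n, f in [("a", 3.0), ("b", 9.0), ("c", 1.0), ("d", 7.0)]]
population_size = len(population)
passthrough_perc, mutating_perc = 0.25, 0.5

population.sort(key=lambda k: k["fit"], reverse=True)
cutoff_passthrough = math.floor(population_size * passthrough_perc)  # elites kept as-is
cutoff_mutating = math.floor(population_size * mutating_perc)        # parents for clones
source_agents = population[:cutoff_mutating]

new_population = population[:cutoff_passthrough]
k = 0
while len(new_population) < population_size:
    new_population.append({"name": source_agents[k]["name"],  # clone of a parent
                           "fit": -1, "need_mut": True, "age": 0})
    k = (k + 1) % len(source_agents)

print([(a["name"], a["need_mut"]) for a in new_population])
# [('b', False), ('b', True), ('d', True), ('b', True)]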