def test_exact_div():
    """Check utils.exact_div on exact divisions and on a non-exact one.

    Exact divisions must return the integer quotient; a divisor that does not
    evenly divide the dividend must raise ValueError.
    """
    assert utils.exact_div(12, 4) == 3
    assert utils.exact_div(12, 3) == 4
    try:
        utils.exact_div(7, 3)
    except ValueError:
        pass
    else:
        # Raise explicitly instead of `assert False`: plain asserts are
        # stripped under `python -O`, which would make this check vacuous.
        raise AssertionError("utils.exact_div(7, 3) should raise ValueError")
# Example 2
 def validate(self, *, prefix=''):
     """Run the parent validation, then sanity-check the PPO minibatch size.

     utils.exact_div raises ValueError unless ppo.batch_size is divisible by
     ppo.nminibatches. When reward whitening is enabled, each minibatch must
     also contain at least 8 samples.
     """
     super().validate(prefix=prefix)
     ppo = self.ppo
     # NOTE: this size is summed over all ranks; the real per-rank minibatch
     # is smaller still (it must additionally be divided by the rank count).
     minibatch_size = utils.exact_div(ppo.batch_size, ppo.nminibatches)
     if not ppo.whiten_rewards:
         return
     assert minibatch_size >= 8, \
         f"Minibatch size {minibatch_size} is insufficient for whitening in PPOTrainer.loss"
# Example 3
    def train(self):
        """Fetch labels, train on them in shuffled batches, optionally normalize.

        Labels are downloaded and added to the buffer. Indices
        ``0..labels.num_train-1`` are then walked in one permutation shared by
        every rank. Each global batch of ``hparams.batch_size`` is split
        round-robin across ranks. The learning rate decays linearly from
        ``hparams.lr`` as training progresses. Normalization can run before
        training (to the current target mean/std) and/or after training
        (to mean 0, std 1).
        """
        hp = self.hparams
        num_train = hp.labels.num_train

        labels = download_labels(hp.labels.source,
                                 label_type=self.label_type,
                                 question_schemas=self.question_schemas,
                                 total_labels=num_train,
                                 comm=self.comm)
        self.add_to_buffer(labels)

        if hp.normalize_before:
            mean, std = self.target_mean_std()
            self.normalize(self.sample_policy_responses, mean, std)

        # Every rank gets an equal share of each global batch; exact_div
        # raises if batch_size does not split evenly across ranks.
        rank_batch = utils.exact_div(hp.batch_size, self.num_ranks)

        # Broadcast one permutation so every rank agrees on the ordering and
        # each training point is visited exactly once.
        shuffled = self.comm.bcast(np.random.permutation(num_train))

        print(self.rank, "training on", num_train, "in batches of", rank_batch)
        for batch_start in range(0, num_train, hp.batch_size):
            global_batch = shuffled[batch_start:batch_start + hp.batch_size]
            local_batch = global_batch[self.rank::self.num_ranks]
            decay = 1 - batch_start / num_train
            self.train_batch(local_batch, decay * hp.lr)

        if hp.normalize_after:
            self.normalize(self.sample_policy_responses, np.zeros([]), np.ones([]))
# Example 4
def past_shape(*, hparams, batch_size=None, sequence=None):
    """Build the shape list for a "past" tensor.

    Layout: ``[batch_size, n_layer, 2, n_head, sequence, head_dim]``, where
    ``head_dim = n_embd / n_head``. utils.exact_div raises ValueError if that
    division is not exact.

    ``batch_size`` and ``sequence`` default to None, presumably to leave those
    dimensions unspecified. The axis of size 2 presumably separates keys from
    values -- TODO confirm.
    """
    head_dim = utils.exact_div(hparams.n_embd, hparams.n_head)
    leading = [batch_size, hparams.n_layer, 2, hparams.n_head, sequence]
    return leading + [head_dim]
# Example 5
def sample_policy(save_dir=None,
                  savescope='policy',
                  temperature=1.0,
                  seed=None,
                  batch_size=4,
                  nsamples=0):
    """Load a trained policy from ``save_dir`` and print query/response samples.

    Args:
        save_dir: run directory containing ``train_policy_hparams.json`` and a
            ``policy`` subdirectory. Required despite the ``None`` default,
            because ``os.path.join(None, ...)`` fails.
        savescope: not used by this function.
        temperature: sampling temperature passed to the Policy.
        seed: seed passed to ``utils.set_mpi_seed``.
        batch_size: number of queries sampled per step on each rank.
        nsamples: total number of samples to print across all MPI ranks. It
            must be divisible by the number of ranks; 0 means sample forever.

    Raises:
        ValueError: from ``utils.exact_div`` if ``nsamples`` does not split
            evenly across ranks.
    """
    hparams = train_policy.HParams()
    hparams.override_from_json_file(
        os.path.join(save_dir, 'train_policy_hparams.json'))
    print('hparams', hparams)
    task = hparams.task

    comm = MPI.COMM_WORLD
    nsamples_per_rank = utils.exact_div(nsamples, comm.Get_size())
    with tf.Graph().as_default():
        m = trained_models.TrainedModel(name='sample',
                                        savedir=os.path.join(
                                            save_dir, 'policy'),
                                        scope='policy')
        encoder = m.encoding.get_encoder()
        hyperparams.dump(m.hparams(), name='model_hparams')

        utils.set_mpi_seed(seed)

        policy = Policy(
            m,
            scope='policy',
            is_root=True,  # just init on every rank, simplifies code
            embed_queries=lm_tasks.query_formatter(task, encoder),
            temperature=temperature,
        )

        query_sampler = lm_tasks.make_query_sampler(hparams=task,
                                                    encoder=encoder,
                                                    comm=comm,
                                                    batch_size=batch_size,
                                                    mode='test')

        init_ops = tf.group(
            tf.global_variables_initializer(),
            tf.local_variables_initializer(),
        )

        with utils.mpi_session():
            init_ops.run()

            @utils.graph_function()
            def sample_queries():
                return query_sampler()['tokens']

            tf.get_default_graph().finalize()

            generated = 0
            while nsamples_per_rank == 0 or generated < nsamples_per_rank:
                queries = sample_queries()
                rollouts = policy.respond(queries, length=task.response_length)
                query_rows = queries.tolist()
                response_rows = rollouts['responses'].tolist()
                assert len(query_rows) == batch_size
                assert len(response_rows) == batch_size
                pairs = list(zip(query_rows, response_rows))
                if nsamples_per_rank:
                    # If batch_size does not divide nsamples_per_rank, the last
                    # batch would overshoot; print only the samples still owed.
                    pairs = pairs[:nsamples_per_rank - generated]
                for q, r in pairs:
                    print('=' * 80)
                    print(encoder.decode(q).replace("\n", "⏎"))
                    print(encoder.decode(r).replace("\n", "⏎"))
                generated += batch_size
# Example 6
def train(hparams: HParams):
    """Train a policy with PPO, optionally training a fresh reward model first.

    Steps:
      1. If ``hparams.rewards.train_new_model`` is set, train a reward model
         with ``train_reward.train`` into the same ``save_dir``. Then point
         ``hparams.rewards.trained_model`` at the result.
      2. In a fresh TF graph, build the score model, the reference and
         trainable policies, the query sampler and a ``PPOTrainer``.
      3. Call ``ppo_trainer.step()`` until the global step reaches
         ``nupdates(hparams)``.

    When ``hparams.run.save_dir`` is set, every rank writes the hparams and
    encoding files. Only rank 0 saves checkpoints: one every
    ``hparams.run.save_interval`` steps, plus one more when the loop exits
    (normally or via an exception).

    Raises:
        AssertionError: if the reward-model task differs from the policy task,
            or if reward whitening is enabled with fewer than 8 samples per
            per-rank minibatch.
        ValueError: from ``utils.exact_div`` if the PPO batch size does not
            split evenly across ranks and minibatches.
    """
    save_dir = hparams.run.save_dir
    if hparams.rewards.train_new_model:
        assert hparams.task == hparams.rewards.train_new_model.task, f'{hparams.task} != {hparams.rewards.train_new_model.task}'
        # Train the reward model into the same run directory as the policy.
        hparams.rewards.train_new_model.run.save_dir = save_dir
        train_reward.train(hparams.rewards.train_new_model)
        if 'pytest' in sys.modules:
            # NOTE(review): 'test' looks like a placeholder model name used
            # under pytest -- confirm TrainedRewardModel accepts it.
            hparams.rewards.trained_model = 'test'
        elif save_dir:
            # save_dir is truthy in this branch, so no None check is needed.
            hparams.rewards.trained_model = os.path.join(save_dir, 'reward_model')
        # NOTE(review): outside pytest with no save_dir, trained_model keeps
        # its configured value, so the freshly trained reward model appears
        # unused -- confirm this is intended.

    comm = MPI.COMM_WORLD

    with tf.Graph().as_default():
        hyperparams.dump(hparams)

        m = trained_models.TrainedModel(hparams.task.policy.initial_model)
        encoder = m.encoding.get_encoder()
        hyperparams.dump(m.hparams(), name='model_hparams')

        if save_dir:
            # gs:// paths are not created with os.makedirs.
            if not save_dir.startswith('gs://'):
                os.makedirs(os.path.join(save_dir, 'policy'), exist_ok=True)
            with tf.gfile.Open(os.path.join(save_dir, 'train_policy_hparams.json'), 'w') as f:
                json.dump(hparams.to_nested_dict(), f, indent=2)
            with tf.gfile.Open(os.path.join(save_dir, 'policy', 'hparams.json'), 'w') as f:
                json.dump(m.hparams().to_nested_dict(), f, indent=2)
            with tf.gfile.Open(os.path.join(save_dir, 'policy', 'encoding'), 'w') as f:
                json.dump(m.encoding.name, f, indent=2)
        utils.set_mpi_seed(hparams.run.seed)

        score_model = TrainedRewardModel(hparams.rewards.trained_model, m.encoding, comm=comm)

        ref_policy = Policy(
            m, scope='ref_policy',
            is_root=comm.Get_rank() == 0,
            embed_queries=lm_tasks.query_formatter(hparams.task, encoder),
            temperature=hparams.task.policy.temperature,
            build_respond=False)

        policy = Policy(
            m, scope='policy',
            is_root=comm.Get_rank() == 0,
            embed_queries=lm_tasks.query_formatter(hparams.task, encoder),
            temperature=hparams.task.policy.temperature)

        # Each rank samples its share of the global PPO batch.
        query_sampler = lm_tasks.make_query_sampler(
            hparams=hparams.task, encoder=encoder, comm=comm,
            batch_size=utils.exact_div(hparams.ppo.batch_size, comm.Get_size()),
        )

        per_rank_minibatch_size = utils.exact_div(hparams.ppo.batch_size, hparams.ppo.nminibatches * comm.Get_size())
        if hparams.ppo.whiten_rewards:
            assert per_rank_minibatch_size >= 8, \
                f"Per-rank minibatch size {per_rank_minibatch_size} is insufficient for whitening"

        global_step = tf.train.get_or_create_global_step()
        increment_global_step = tf.group(global_step.assign_add(1))

        with utils.variables_on_gpu():

            ppo_trainer = PPOTrainer(
                policy=policy, ref_policy=ref_policy, query_sampler=query_sampler,
                score_fn=make_score_fn(hparams.task, score_model=score_model),
                hparams=hparams, comm=comm)

        # Only rank 0 writes checkpoints.
        if comm.Get_rank() == 0 and save_dir:
            print(f"Will save to {save_dir}")
            saver = tf.train.Saver(max_to_keep=20, save_relative_paths=True)
            checkpoint_dir = os.path.join(save_dir, 'policy/checkpoints/model.ckpt')
        else:
            saver = None
            checkpoint_dir = None

        @utils.graph_function()
        def sync_models():
            score_model.ensure_built()
            return utils.variable_synchronizer(comm, vars=score_model.get_params() + ref_policy.get_params() + policy.get_params())

        init_ops = tf.group(
            tf.global_variables_initializer(),
            tf.local_variables_initializer(),
            summary.summary_writer_initializer_op())

        with utils.mpi_session() as sess:
            init_ops.run()

            sync_models()

            tf.get_default_graph().finalize()

            try:
                while global_step.eval() < nupdates(hparams):
                    ppo_trainer.step()
                    increment_global_step.run()

                    if saver and global_step.eval() % hparams.run.save_interval == 0:
                        saver.save(sess, checkpoint_dir, global_step=global_step)
            finally:
                # Always write a final checkpoint, even if training raised.
                if saver:
                    saver.save(sess, checkpoint_dir, global_step=global_step)
# Example 7
    def __init__(self, *, policy, ref_policy, query_sampler, score_fn, hparams, comm):
        """Set up PPO training for ``policy``.

        Stores the collaborators on ``self`` and picks a KL controller:
        fixed if ``hparams.rewards.adaptive_kl`` is None, adaptive otherwise.
        It then defines these callables:

        * ``self.sample_queries``: returns ``query_sampler()['tokens']``.
        * ``self.compute_rewards``: per-token reward ``-kl_coef * (logprobs -
          ref_logprobs)``, with the score added on the last token.
        * ``self.train``: ``hparams.ppo.noptepochs`` epochs of shuffled
          minibatch PPO updates.
        * ``self.record_step_stats``: averages stats across ranks and writes
          summaries.

        Args:
            policy: trainable policy; its params are optimized by PPO.
            ref_policy: reference policy (presumably the source of
                ``ref_logprobs`` for the KL penalty -- confirm in caller).
            query_sampler: callable returning a dict with a ``'tokens'`` entry.
            score_fn: scoring function exposing ``stat_schemas``.
            hparams: hyperparameters with ``ppo``, ``rewards``, ``task`` and
                ``run`` sections.
            comm: MPI communicator. ``hparams.ppo.batch_size`` must be
                divisible by its size, and the per-rank batch by
                ``hparams.ppo.nminibatches``.

        Raises:
            ValueError: from ``utils.exact_div`` when those sizes do not
                divide evenly.
        """
        self.comm = comm
        self.policy = policy
        self.ref_policy = ref_policy
        self.score_fn = score_fn
        self.hparams = hparams

        if hparams.rewards.adaptive_kl is None:
            self.kl_ctl = FixedKLController(hparams.rewards.kl_coef)
        else:
            self.kl_ctl = AdaptiveKLController(hparams.rewards.kl_coef, hparams=hparams.rewards.adaptive_kl)

        response_length = hparams.task.response_length
        query_length = hparams.task.query_length

        @utils.graph_function()
        def sample_queries():
            return query_sampler()['tokens']
        self.sample_queries = sample_queries

        def compute_rewards(scores, logprobs, ref_logprobs):
            """Return (rewards, non_score_reward, kl_coef) for a batch.

            Assumes logprobs/ref_logprobs are (batch, response_length) arrays
            and scores has length batch -- TODO confirm with caller.
            """
            kl = logprobs - ref_logprobs
            non_score_reward = -self.kl_ctl.value * kl
            rewards = non_score_reward.copy()
            # The scalar score is credited only to the final response token.
            rewards[:, -1] += scores
            return rewards, non_score_reward, self.kl_ctl.value
        self.compute_rewards = compute_rewards

        # per rank sizes
        per_rank_rollout_batch_size = utils.exact_div(hparams.ppo.batch_size, comm.Get_size())
        per_rank_minibatch_size = utils.exact_div(per_rank_rollout_batch_size, hparams.ppo.nminibatches)

        @utils.graph_function(
            rollouts=dict(
                queries=Schema(tf.int32, (per_rank_minibatch_size, query_length)),
                responses=Schema(tf.int32, (per_rank_minibatch_size, response_length)),
                values=Schema(tf.float32, (per_rank_minibatch_size, response_length)),
                logprobs=Schema(tf.float32, (per_rank_minibatch_size, response_length)),
                rewards=Schema(tf.float32, (per_rank_minibatch_size, response_length)),
            ))
        def train_minibatch(rollouts):
            """One step of PPO training."""

            # Learning rate scaled by 1 - policy_frac(hparams); presumably
            # policy_frac is the fraction of training completed -- confirm.
            left = 1 - policy_frac(hparams)
            lrnow = hparams.ppo.lr * left

            ppo_loss, stats = self.loss(rollouts)
            ppo_train_op = utils.minimize(
                loss=ppo_loss, lr=lrnow, params=policy.get_params(), name='ppo_opt', comm=self.comm)
            return ppo_train_op, stats

        def train(rollouts):
            """Run noptepochs epochs of minibatch PPO on this rank's rollouts.

            ``rollouts`` maps names to arrays indexed by rollout along the first
            axis (assumed length per_rank_rollout_batch_size -- confirm).
            Returns a dict mapping each stat name to its list of per-minibatch
            values.
            """
            stat_list = []

            # Do multiple epochs of PPO training, with a fresh random shuffle in each epoch
            for _epoch in range(hparams.ppo.noptepochs):
                order = np.random.permutation(per_rank_rollout_batch_size)
                for mb_start in range(0, per_rank_rollout_batch_size, per_rank_minibatch_size):
                    mb_data = {k: v[order[mb_start:mb_start+per_rank_minibatch_size]]
                               for k, v in rollouts.items()}

                    _, stats = train_minibatch(mb_data)
                    stat_list.append(stats)

            # Collect the stats. (They will be averaged later.)
            return {k: [s[k] for s in stat_list] for k in stat_list[0].keys()}
        self.train = train

        # NOTE: must line up with stats created in self.loss (TODO: better solution?)
        scalar_batch = Schema(tf.float32, (None,))
        ppo_stat_schemas = utils.flatten_dict(dict(
            loss=dict(policy=scalar_batch, value=scalar_batch, total=scalar_batch),
            policy=dict(entropy=scalar_batch, approxkl=scalar_batch, clipfrac=scalar_batch),
            returns=dict(mean=scalar_batch, var=scalar_batch),
            val=dict(vpred=scalar_batch, error=scalar_batch, clipfrac=scalar_batch, mean=scalar_batch, var=scalar_batch),
        ), sep='/')
        stat_data_schemas = dict(
            logprobs=Schema(tf.float32, (None, hparams.task.response_length)),
            ref_logprobs=Schema(tf.float32, (None, hparams.task.response_length)),
            scores=scalar_batch,
            non_score_reward=Schema(tf.float32, (None, hparams.task.response_length)),
            score_stats=score_fn.stat_schemas,
            train_stats=ppo_stat_schemas,
        )
        @utils.graph_function(
            **stat_data_schemas, kl_coef=Schema(tf.float32, ()))
        def record_step_stats(*, kl_coef, **data):
            """Average step statistics across ranks and record them.

            Returns (record_op, stats). record_op is a no-op when no summary
            writer is available.
            """
            ppo_summary_writer = utils.get_summary_writer(self.hparams.run.save_dir, subdir='ppo', comm=self.comm)

            kl = data['logprobs'] - data['ref_logprobs']
            mean_kl = tf.reduce_mean(tf.reduce_sum(kl, axis=1))
            mean_entropy = tf.reduce_mean(tf.reduce_sum(-data['logprobs'], axis=1))
            mean_non_score_reward = tf.reduce_mean(tf.reduce_sum(data['non_score_reward'], axis=1))
            stats = {
                'objective/kl': mean_kl,
                'objective/kl_coef': kl_coef,
                'objective/entropy': mean_entropy,
            }
            for k, v in data['train_stats'].items():
                stats[f'ppo/{k}'] = tf.reduce_mean(v, axis=0)
            for k, v in data['score_stats'].items():
                mean = tf.reduce_mean(v, axis=0)
                stats[f'objective/{k}'] = mean
                stats[f'objective/{k}_total'] = mean + mean_non_score_reward

            stats = utils.FlatStats.from_dict(stats).map_flat(
                partial(utils.mpi_allreduce_mean, comm=self.comm)).as_dict()

            # Add more statistics
            step = tf.train.get_global_step().read_value()
            stats['ppo/val/var_explained'] = 1 - stats['ppo/val/error'] / stats['ppo/returns/var']
            steps = step + 1
            stats.update({
                'elapsed/updates': steps,
                'elapsed/steps/serial': steps * hparams.task.response_length,
                'elapsed/steps/total': steps * hparams.ppo.batch_size * hparams.task.response_length,
                'elapsed/episodes': steps * hparams.ppo.batch_size,
            })

            # Time statistics
            total, delta = tf_times()
            stats.update({
                'elapsed/fps': tf.cast(hparams.ppo.batch_size * hparams.task.response_length / delta, tf.int32),
                'elapsed/time': total,
            })
            if ppo_summary_writer:
                record_op = utils.record_stats(
                    stats=stats, summary_writer=ppo_summary_writer, step=step, log_interval=hparams.run.log_interval, name='ppo_stats', comm=self.comm)
            else:
                record_op = tf.no_op()
            return record_op, stats
        self.record_step_stats = record_step_stats
# Example 8
 def validate(self, *, prefix=''):
     """Run the parent validation, then require batch_size to divide labels.num_train.

     utils.exact_div raises ValueError when the division is not exact.
     """
     super().validate(prefix=prefix)
     # Called only for its divisibility check; the quotient is discarded.
     _ = utils.exact_div(self.labels.num_train, self.batch_size)
# Example 9
def train(hparams: HParams):
    """Train a reward model in a fresh TF graph and save it on rank 0.

    Builds a reference policy and a reward-model trainer from
    ``hparams.task.policy.initial_model``, plus a query sampler that draws
    ``hparams.rollout_batch_size / num_ranks`` queries per rank. It then
    synchronizes parameters across ranks and runs ``RewardModelTrainer.train``.

    Metadata and checkpoints are written only on rank 0, and only when
    ``hparams.run.save_dir`` is set. The metadata consists of the run
    hparams, the reward-model hparams and the encoding name. A single
    checkpoint is saved after training finishes.
    """

    def write_metadata(save_dir, reward_model):
        """Create the reward_model directory (non-gs:// only) and dump metadata JSON."""
        if not save_dir.startswith('gs://'):
            os.makedirs(os.path.join(save_dir, 'reward_model'), exist_ok=True)
        # Each payload is computed only after its file has been opened,
        # matching a straight-line sequence of writes.
        outputs = (
            (('train_reward_hparams.json',),
             lambda: hparams.to_nested_dict()),
            (('reward_model', 'hparams.json'),
             lambda: reward_model.hparams.to_nested_dict()),
            (('reward_model', 'encoding'),
             lambda: reward_model.trained_model.encoding.name),
        )
        for parts, payload in outputs:
            with tf.gfile.Open(os.path.join(save_dir, *parts), 'w') as f:
                json.dump(payload(), f, indent=2)

    with tf.Graph().as_default():
        hyperparams.dump(hparams)
        utils.set_mpi_seed(hparams.run.seed)

        model = trained_models.TrainedModel(hparams.task.policy.initial_model)
        encoder = model.encoding.get_encoder()
        hyperparams.dump(model.hparams(), name='model_hparams')

        comm = MPI.COMM_WORLD
        is_root = comm.Get_rank() == 0

        ref_policy = Policy(
            model,
            scope='ref_policy',
            is_root=is_root,
            embed_queries=lm_tasks.query_formatter(hparams.task, encoder),
            temperature=hparams.task.policy.temperature,
            build_respond=False,
        )

        reward_model = rewards.RewardModelTrainer(model, is_root=is_root)

        per_rank_rollouts = utils.exact_div(hparams.rollout_batch_size,
                                            comm.Get_size())
        query_sampler = lm_tasks.make_query_sampler(
            hparams=hparams.task,
            encoder=encoder,
            comm=comm,
            batch_size=per_rank_rollouts)

        tf.train.create_global_step()

        reward_trainer = RewardModelTrainer(
            reward_model=reward_model,
            policy=ref_policy,
            query_sampler=query_sampler,
            hparams=hparams,
            comm=comm,
        )

        save_dir = hparams.run.save_dir
        saver = checkpoint_dir = None
        if is_root and save_dir:
            print(f"Will save to {save_dir}")
            saver = tf.train.Saver(max_to_keep=20, save_relative_paths=True)
            checkpoint_dir = os.path.join(
                save_dir, 'reward_model/checkpoints/model.ckpt')
            write_metadata(save_dir, reward_model)

        with utils.variables_on_gpu():
            init_ops = tf.group(tf.global_variables_initializer(),
                                tf.local_variables_initializer(),
                                summary.summary_writer_initializer_op())

            @utils.graph_function()
            def sync_models():
                params = ref_policy.get_params() + reward_model.get_params()
                return utils.variable_synchronizer(comm, vars=params)

        tf.get_default_graph().finalize()

        with utils.mpi_session() as sess:
            init_ops.run()
            sync_models()
            reward_trainer.train()
            if saver:
                saver.save(sess, checkpoint_dir)