Example No. 1
    def __init__(
        self,
        trained_model,
        *,
        scope='reward_model',
        use_resource=False,
        is_root=True,
    ):
        self.trained_model = trained_model
        self.hparams = trained_model.hparams()
        self.is_root = is_root

        self.use_resource = use_resource
        self.encoder = self.trained_model.encoding.get_encoder()

        self.scope = scope
        self.model = model.Model(hparams=self.hparams,
                                 scope=f'{scope}/model',
                                 scalar_heads=['reward'])

        self.built = False
        self.padding_token = self.encoder.padding_token

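        # Compile get_rewards_op into a graph function that takes batched
        # query and response token ids (int32, shape [batch, length]).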
        self.get_rewards = utils.graph_function(
            queries=Schema(tf.int32, (None, None)),
            responses=Schema(tf.int32, (None, None)),
        )(self.get_rewards_op)
Example No. 2
    def __init__(
            self,
            trained_model, *,
            scope=None, use_resource=False,
            embed_queries=lambda queries: queries,
            temperature=1.0, is_root=True,
            build_respond=True,
    ):
        self.trained_model = trained_model
        self.model_hparams = trained_model.hparams()
        self.is_root = is_root

        self.use_resource = use_resource
        self.encoder = self.trained_model.encoding.get_encoder()

        with tf.variable_scope(scope, 'transformer_policy', use_resource=self.use_resource) as s:
            self.scope = s
            self.model = model.Model(
                hparams=self.model_hparams,
                scalar_heads=['value'])

        self.built = False
        self.embed_queries = embed_queries
        self.temperature = temperature
        self.padding_token = self.encoder.padding_token

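        # respond is only compiled when build_respond is set;
        # analyze_responses is always available as a graph function.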
        if build_respond:
            self.respond = utils.graph_function(
                queries=Schema(tf.int32, (None, None)),
                length=Schema(tf.int32, ()),
            )(self.respond_op)
        self.analyze_responses = utils.graph_function(
            queries=Schema(tf.int32, (None, None)),
            responses=Schema(tf.int32, (None, None)),
        )(self.analyze_responses_op)
Example No. 3
def make_score_fn(hparams, score_model):
    padding_token = score_model.padding_token

    postprocess_fn = lm_tasks.postprocess_fn_from_hparams(hparams, padding_token)
    # decorate requires a named function; postprocess_fn can be anonymous
    @utils.graph_function(responses=Schema(tf.int32, (None, None)))
    def postprocess(responses):
        return postprocess_fn(responses)

    filter_fn = lm_tasks.filter_fn_from_hparams(hparams)
    @utils.graph_function(
        responses=Schema(tf.int32, (None, None)),
        rewards=Schema(tf.float32, (None,)))
    def penalize(responses, rewards):
        valid = filter_fn(responses)
        return tf.where(valid, rewards, hparams.penalty_reward_value * tf.ones_like(rewards))

    @utils.graph_function(
        queries=Schema(tf.int32, (None, None)),
        responses=Schema(tf.int32, (None, None))
    )
    def unpenalized_score_fn(queries, responses):
        return score_model.score_fn(queries, responses)

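    # Combined scoring: postprocess the responses, score them with the reward
    # model, then overwrite rewards for filtered-out responses with the
    # penalty value.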
    def score_fn(queries, responses):
        responses = postprocess(responses)
        score = penalize(responses, unpenalized_score_fn(queries, responses))
        return score, responses, dict(score=score)
    score_fn.stat_schemas = dict(score=Schema(tf.float32, (None,)))
    return score_fn
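Example No. 4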
    def label_schemas(self):
        return dict(difference=Schema(tf.float32, ()))
Example No. 5
    def __init__(self, *, policy, ref_policy, query_sampler, score_fn, hparams, comm):
        self.comm = comm
        self.policy = policy
        self.ref_policy = ref_policy
        self.score_fn = score_fn
        self.hparams = hparams

        if hparams.rewards.adaptive_kl is None:
            self.kl_ctl = FixedKLController(hparams.rewards.kl_coef)
        else:
            self.kl_ctl = AdaptiveKLController(hparams.rewards.kl_coef, hparams=hparams.rewards.adaptive_kl)

        response_length = hparams.task.response_length
        query_length = hparams.task.query_length

        @utils.graph_function()
        def sample_queries():
            return query_sampler()['tokens']
        self.sample_queries = sample_queries

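        # Per-token reward is the negative KL penalty against the reference
        # policy, with the score added to the last token of each response.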
        def compute_rewards(scores, logprobs, ref_logprobs):
            kl = logprobs - ref_logprobs
            non_score_reward = -self.kl_ctl.value * kl
            rewards = non_score_reward.copy()
            rewards[:, -1] += scores
            return rewards, non_score_reward, self.kl_ctl.value
        self.compute_rewards = compute_rewards

        # per rank sizes
        per_rank_rollout_batch_size = utils.exact_div(hparams.ppo.batch_size, comm.Get_size())
        per_rank_minibatch_size = utils.exact_div(per_rank_rollout_batch_size, hparams.ppo.nminibatches)

        @utils.graph_function(
            rollouts=dict(
                queries=Schema(tf.int32, (per_rank_minibatch_size, query_length)),
                responses=Schema(tf.int32, (per_rank_minibatch_size, response_length)),
                values=Schema(tf.float32, (per_rank_minibatch_size, response_length)),
                logprobs=Schema(tf.float32, (per_rank_minibatch_size, response_length)),
                rewards=Schema(tf.float32, (per_rank_minibatch_size, response_length)),
            ))
        def train_minibatch(rollouts):
            """One step of PPO training."""

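            # Scale the learning rate by the fraction of training remaining.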
            left = 1 - policy_frac(hparams)
            lrnow = hparams.ppo.lr * left

            ppo_loss, stats = self.loss(rollouts)
            ppo_train_op = utils.minimize(
                loss=ppo_loss, lr=lrnow, params=policy.get_params(), name='ppo_opt', comm=self.comm)
            return ppo_train_op, stats

        def train(rollouts):
            stat_list = []

            # Do multiple epochs of PPO training, with a fresh random shuffle in each epoch
            for ppo_epoch_idx in range(hparams.ppo.noptepochs):
                order = np.random.permutation(per_rank_rollout_batch_size)
                for mb_start in range(0, per_rank_rollout_batch_size, per_rank_minibatch_size):
                    mb_data = {k: v[order[mb_start:mb_start+per_rank_minibatch_size]]
                               for k, v in rollouts.items()}

                    step = tf.train.get_global_step().eval()

                    _, stats = train_minibatch(mb_data)
                    stat_list.append(stats)

            # Collect the stats. (They will be averaged later.)
            return {k: [s[k] for s in stat_list] for k in stat_list[0].keys()}
        self.train = train

        # NOTE: must line up with stats created in self.loss (TODO: better solution?)
        scalar_batch = Schema(tf.float32, (None,))
        ppo_stat_schemas = utils.flatten_dict(dict(
            loss=dict(policy=scalar_batch, value=scalar_batch, total=scalar_batch),
            policy=dict(entropy=scalar_batch, approxkl=scalar_batch, clipfrac=scalar_batch),
            returns=dict(mean=scalar_batch, var=scalar_batch),
            val=dict(vpred=scalar_batch, error=scalar_batch, clipfrac=scalar_batch, mean=scalar_batch, var=scalar_batch),
        ), sep='/')
        stat_data_schemas = dict(
            logprobs=Schema(tf.float32, (None, hparams.task.response_length)),
            ref_logprobs=Schema(tf.float32, (None, hparams.task.response_length)),
            scores=scalar_batch,
            non_score_reward=Schema(tf.float32, (None, hparams.task.response_length)),
            score_stats=score_fn.stat_schemas,
            train_stats=ppo_stat_schemas,
        )
        @utils.graph_function(
            **stat_data_schemas, kl_coef=Schema(tf.float32, ()))
        def record_step_stats(*, kl_coef, **data):
            ppo_summary_writer = utils.get_summary_writer(self.hparams.run.save_dir, subdir='ppo', comm=self.comm)

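            # Per-token KL between the policy and the reference policy, summed
            # over the response and averaged over the batch.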
            kl = data['logprobs'] - data['ref_logprobs']
            mean_kl = tf.reduce_mean(tf.reduce_sum(kl, axis=1))
            mean_entropy = tf.reduce_mean(tf.reduce_sum(-data['logprobs'], axis=1))
            mean_non_score_reward = tf.reduce_mean(tf.reduce_sum(data['non_score_reward'], axis=1))
            stats = {
                'objective/kl': mean_kl,
                'objective/kl_coef': kl_coef,
                'objective/entropy': mean_entropy,
            }
            for k, v in data['train_stats'].items():
                stats[f'ppo/{k}'] = tf.reduce_mean(v, axis=0)
            for k, v in data['score_stats'].items():
                mean = tf.reduce_mean(v, axis=0)
                stats[f'objective/{k}'] = mean
                stats[f'objective/{k}_total'] = mean + mean_non_score_reward

            stats = utils.FlatStats.from_dict(stats).map_flat(
                partial(utils.mpi_allreduce_mean, comm=self.comm)).as_dict()

            # Add more statistics
            step = tf.train.get_global_step().read_value()
            stats['ppo/val/var_explained'] = 1 - stats['ppo/val/error'] / stats['ppo/returns/var']
            steps = step + 1
            stats.update({
                'elapsed/updates': steps,
                'elapsed/steps/serial': steps * hparams.task.response_length,
                'elapsed/steps/total': steps * hparams.ppo.batch_size * hparams.task.response_length,
                'elapsed/episodes': steps * hparams.ppo.batch_size,
            })

            # Time statistics
            total, delta = tf_times()
            stats.update({
                'elapsed/fps': tf.cast(hparams.ppo.batch_size * hparams.task.response_length / delta, tf.int32),
                'elapsed/time': total,
            })
            if ppo_summary_writer:
                record_op = utils.record_stats(
                    stats=stats, summary_writer=ppo_summary_writer, step=step, log_interval=hparams.run.log_interval, name='ppo_stats', comm=self.comm)
            else:
                record_op = tf.no_op()
            return record_op, stats
        self.record_step_stats = record_step_stats
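Example No. 6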
    def question_schemas(self, *, query_length, response_length) -> Dict[str, Schema]:
        return dict(
            query=Schema(tf.int32, (query_length,)),
            sample=Schema(tf.int32, (response_length,)),
        )
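Example No. 7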
    def label_schemas(self):
        return dict(score=Schema(tf.float32, ()))
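Example No. 8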
    def question_schemas(self, *, query_length, response_length) -> Dict[str, Schema]:
        return dict(
            query=Schema(tf.int32, (query_length,)),
            **{f"sample{i}": Schema(tf.int32, (response_length,)) for i in range(self.num_responses)}
        )
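Example No. 9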
    def label_schemas(self):
        return dict(best=Schema(tf.int32, ()))
Example No. 10
    def __init__(self, *, reward_model, policy, query_sampler, hparams, comm):
        self.reward_model = reward_model

        self.policy = policy
        self.hparams = hparams
        self.num_ranks = comm.Get_size()
        self.rank = comm.Get_rank()
        self.comm = comm

        self.label_type = label_types.get(hparams.labels.type)
        self.question_schemas = self.label_type.question_schemas(
            query_length=hparams.task.query_length,
            response_length=hparams.task.response_length,
        )

        data_schemas = {
            **self.question_schemas,
            **self.label_type.label_schemas(),
        }

        with tf.device(None), tf.device('/cpu:0'):
            with tf.variable_scope('label_buffer',
                                   use_resource=True,
                                   initializer=tf.zeros_initializer):
                self.train_buffer = utils.SampleBuffer(
                    capacity=hparams.labels.num_train, schemas=data_schemas)

        with tf.name_scope('train_reward'):
            summary_writer = utils.get_summary_writer(
                self.hparams.run.save_dir, subdir='reward_model', comm=comm)

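            # One reward-model update: read a labelled minibatch from the
            # buffer, compute the label-type loss, take an MPI-averaged
            # optimizer step, and record the resulting stats.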
            @utils.graph_function(indices=Schema(tf.int32, (None, )),
                                  lr=Schema(tf.float32, ()))
            def train_batch(indices, lr):
                with tf.name_scope('minibatch'):
                    minibatch = self.train_buffer.read(indices)
                    stats = self.label_type.loss(
                        reward_model=self.reward_model.get_rewards_op,
                        labels=minibatch)

                    train_op = utils.minimize(
                        loss=stats['loss'],
                        lr=lr,
                        params=self.reward_model.get_params(),
                        name='opt',
                        comm=self.comm)

                    with tf.control_dependencies([train_op]):
                        step_var = tf.get_variable(name='train_step',
                                                   dtype=tf.int64,
                                                   shape=(),
                                                   trainable=False,
                                                   use_resource=True)
                        step = step_var.assign_add(1) - 1

                        stats = utils.FlatStats.from_dict(stats).map_flat(
                            partial(utils.mpi_allreduce_mean,
                                    comm=comm)).as_dict()

                        train_stat_op = utils.record_stats(
                            stats=stats,
                            summary_writer=summary_writer,
                            step=step,
                            log_interval=hparams.run.log_interval,
                            comm=comm)

                return train_stat_op

            self.train_batch = train_batch

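        # Reward-normalization helpers, only built when normalization before
        # or after training is enabled.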
        if self.hparams.normalize_before or self.hparams.normalize_after:

            @utils.graph_function()
            def target_mean_std():
                """Returns the means and variances to target for each reward model"""
                # Should be the same on all ranks because the train_buf should be the same
                scales = self.label_type.target_scales(
                    self.train_buffer.data())
                if scales is None:
                    return tf.zeros([]), tf.ones([])
                else:
                    mean, var = tf.nn.moments(scales, axes=[0])
                    return mean, tf.sqrt(var)

            self.target_mean_std = target_mean_std

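            # Mean and std of rewards over the given query/response batches,
            # aggregated across MPI ranks.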
            def stats(query_responses):
                rewards = np.concatenate([
                    self.reward_model.get_rewards(qs, rs)
                    for qs, rs in query_responses
                ], axis=0)
                assert len(rewards.shape) == 1, f'{rewards.shape}'
                sums = np.asarray(
                    [rewards.sum(axis=0),
                     np.square(rewards).sum(axis=0)])
                means, sqr_means = self.comm.allreduce(
                    sums, op=MPI.SUM) / (self.num_ranks * rewards.shape[0])
                stds = np.sqrt(sqr_means - means**2)
                return means, stds

            self.stats = stats

            def log_stats_after_normalize(stats):
                if comm.Get_rank() != 0:
                    return
                means, stds = stats
                print(f'after normalize: {means} +- {stds}')

            self.log_stats_after_normalize = log_stats_after_normalize

            def reset_reward_scales():
                self.reward_model.reset_reward_scale()

            self.reset_reward_scales = reset_reward_scales

            def set_reward_norms(mean, std, new_mean, new_std):
                print(f'targets: {new_mean} +- {new_std}')
                print(f'before normalize: {mean} +- {std}')
                assert np.isfinite((mean, std, new_mean, new_std)).all()
                self.reward_model.set_reward_norm(old_mean=mean,
                                                  old_std=std,
                                                  new_mean=new_mean,
                                                  new_std=new_std)

            self.set_reward_norms = set_reward_norms

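        # Sample query/response pairs from the current policy, used to measure
        # reward statistics for normalization.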
        if self.hparams.normalize_before or self.hparams.normalize_after:

            @utils.graph_function()
            def sample_policy_batch():
                queries = query_sampler('ref_queries')['tokens']
                responses = policy.respond_op(
                    queries=queries,
                    length=hparams.task.response_length)['responses']
                return queries, responses

            def sample_policy_responses(n_samples):
                n_batches = utils.ceil_div(n_samples,
                                           hparams.rollout_batch_size)
                return [sample_policy_batch() for _ in range(n_batches)]

            self.sample_policy_responses = sample_policy_responses

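        # Graph function that appends a batch of labelled examples to the
        # training buffer.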
        @utils.graph_function(labels=utils.add_batch_dim(data_schemas))
        def add_to_buffer(labels):
            return self.train_buffer.add(**labels)

        self.add_to_buffer = add_to_buffer