def __init__(
        self,
        trained_model,
        *,
        scope='reward_model',
        use_resource=False,
        is_root=True,
):
    self.trained_model = trained_model
    self.hparams = trained_model.hparams()
    self.is_root = is_root

    self.use_resource = use_resource
    self.encoder = self.trained_model.encoding.get_encoder()

    self.scope = scope
    # Language model with a single scalar 'reward' head.
    self.model = model.Model(hparams=self.hparams, scope=f'{scope}/model', scalar_heads=['reward'])

    self.built = False
    self.padding_token = self.encoder.padding_token

    self.get_rewards = utils.graph_function(
        queries=Schema(tf.int32, (None, None)),
        responses=Schema(tf.int32, (None, None)),
    )(self.get_rewards_op)
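# Usage sketch (illustrative only; `reward_model`, `query_tokens`, and
# `response_tokens` are hypothetical names, not defined in this module):
# once built, the graph function above is called with int32 arrays of shape
# (batch, query_len) and (batch, response_len) and yields one scalar reward
# per batch element.
#
#   rewards = reward_model.get_rewards(query_tokens, response_tokens)  # shape (batch,)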
def __init__(
        self,
        trained_model,
        *,
        scope=None,
        use_resource=False,
        embed_queries=lambda queries: queries,
        temperature=1.0,
        is_root=True,
        build_respond=True,
):
    self.trained_model = trained_model
    self.model_hparams = trained_model.hparams()
    self.is_root = is_root

    self.use_resource = use_resource
    self.encoder = self.trained_model.encoding.get_encoder()

    with tf.variable_scope(scope, 'transformer_policy', use_resource=self.use_resource) as s:
        self.scope = s
        # Language model with a single scalar 'value' head.
        self.model = model.Model(
            hparams=self.model_hparams, scalar_heads=['value'])

    self.built = False
    self.embed_queries = embed_queries
    self.temperature = temperature
    self.padding_token = self.encoder.padding_token

    if build_respond:
        self.respond = utils.graph_function(
            queries=Schema(tf.int32, (None, None)),
            length=Schema(tf.int32, ()),
        )(self.respond_op)
    self.analyze_responses = utils.graph_function(
        queries=Schema(tf.int32, (None, None)),
        responses=Schema(tf.int32, (None, None)),
    )(self.analyze_responses_op)
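# Usage sketch (illustrative only; `policy`, `query_tokens`, and
# `response_length` are hypothetical names): the graph functions built above
# are called with numpy arrays. Based on how respond_op is used elsewhere in
# this codebase, the result is a dict that includes a 'responses' key.
#
#   rollout = policy.respond(queries=query_tokens, length=response_length)
#   analyzed = policy.analyze_responses(queries=query_tokens, responses=rollout['responses'])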
def make_score_fn(hparams, score_model):
    padding_token = score_model.padding_token

    postprocess_fn = lm_tasks.postprocess_fn_from_hparams(hparams, padding_token)
    # The decorator requires a named function; postprocess_fn can be anonymous.
    @utils.graph_function(responses=Schema(tf.int32, (None, None)))
    def postprocess(responses):
        return postprocess_fn(responses)

    filter_fn = lm_tasks.filter_fn_from_hparams(hparams)
    @utils.graph_function(
        responses=Schema(tf.int32, (None, None)),
        rewards=Schema(tf.float32, (None,)))
    def penalize(responses, rewards):
        valid = filter_fn(responses)
        return tf.where(valid, rewards, hparams.penalty_reward_value * tf.ones_like(rewards))

    @utils.graph_function(
        queries=Schema(tf.int32, (None, None)),
        responses=Schema(tf.int32, (None, None)))
    def unpenalized_score_fn(queries, responses):
        return score_model.score_fn(queries, responses)

    def score_fn(queries, responses):
        responses = postprocess(responses)
        score = penalize(responses, unpenalized_score_fn(queries, responses))
        return score, responses, dict(score=score)

    score_fn.stat_schemas = dict(score=Schema(tf.float32, (None,)))
    return score_fn
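# Usage sketch (illustrative only; `reward_model`, `queries`, and `responses`
# are hypothetical names): the returned score_fn postprocesses responses,
# scores them with the reward model, and replaces the reward of any response
# rejected by filter_fn with hparams.penalty_reward_value.
#
#   score_fn = make_score_fn(hparams, score_model=reward_model)
#   scores, postprocessed_responses, score_stats = score_fn(queries, responses)
#   # scores: float32 of shape (batch,); score_stats matches score_fn.stat_schemas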
def label_schemas(self):
    return dict(difference=Schema(tf.float32, ()))
def __init__(self, *, policy, ref_policy, query_sampler, score_fn, hparams, comm):
    self.comm = comm
    self.policy = policy
    self.ref_policy = ref_policy
    self.score_fn = score_fn
    self.hparams = hparams

    if hparams.rewards.adaptive_kl is None:
        self.kl_ctl = FixedKLController(hparams.rewards.kl_coef)
    else:
        self.kl_ctl = AdaptiveKLController(hparams.rewards.kl_coef, hparams=hparams.rewards.adaptive_kl)

    response_length = hparams.task.response_length
    query_length = hparams.task.query_length

    @utils.graph_function()
    def sample_queries():
        return query_sampler()['tokens']
    self.sample_queries = sample_queries

    def compute_rewards(scores, logprobs, ref_logprobs):
        kl = logprobs - ref_logprobs
        non_score_reward = -self.kl_ctl.value * kl
        rewards = non_score_reward.copy()
        rewards[:, -1] += scores
        return rewards, non_score_reward, self.kl_ctl.value
    self.compute_rewards = compute_rewards

    # per rank sizes
    per_rank_rollout_batch_size = utils.exact_div(hparams.ppo.batch_size, comm.Get_size())
    per_rank_minibatch_size = utils.exact_div(per_rank_rollout_batch_size, hparams.ppo.nminibatches)

    @utils.graph_function(
        rollouts=dict(
            queries=Schema(tf.int32, (per_rank_minibatch_size, query_length)),
            responses=Schema(tf.int32, (per_rank_minibatch_size, response_length)),
            values=Schema(tf.float32, (per_rank_minibatch_size, response_length)),
            logprobs=Schema(tf.float32, (per_rank_minibatch_size, response_length)),
            rewards=Schema(tf.float32, (per_rank_minibatch_size, response_length)),
        ))
    def train_minibatch(rollouts):
        """One step of PPO training."""
        left = 1 - policy_frac(hparams)
        lrnow = hparams.ppo.lr * left
        ppo_loss, stats = self.loss(rollouts)
        ppo_train_op = utils.minimize(
            loss=ppo_loss, lr=lrnow, params=policy.get_params(), name='ppo_opt', comm=self.comm)
        return ppo_train_op, stats

    def train(rollouts):
        stat_list = []

        # Do multiple epochs of PPO training, with a fresh random shuffle in each epoch
        for ppo_epoch_idx in range(hparams.ppo.noptepochs):
            order = np.random.permutation(per_rank_rollout_batch_size)
            for mb_start in range(0, per_rank_rollout_batch_size, per_rank_minibatch_size):
                mb_data = {k: v[order[mb_start:mb_start + per_rank_minibatch_size]]
                           for k, v in rollouts.items()}

                step = tf.train.get_global_step().eval()

                _, stats = train_minibatch(mb_data)
                stat_list.append(stats)

        # Collect the stats. (They will be averaged later.)
        return {k: [s[k] for s in stat_list] for k in stat_list[0].keys()}
    self.train = train

    # NOTE: must line up with stats created in self.loss (TODO: better solution?)
    scalar_batch = Schema(tf.float32, (None,))
    ppo_stat_schemas = utils.flatten_dict(dict(
        loss=dict(policy=scalar_batch, value=scalar_batch, total=scalar_batch),
        policy=dict(entropy=scalar_batch, approxkl=scalar_batch, clipfrac=scalar_batch),
        returns=dict(mean=scalar_batch, var=scalar_batch),
        val=dict(vpred=scalar_batch, error=scalar_batch, clipfrac=scalar_batch, mean=scalar_batch, var=scalar_batch),
    ), sep='/')
    stat_data_schemas = dict(
        logprobs=Schema(tf.float32, (None, hparams.task.response_length)),
        ref_logprobs=Schema(tf.float32, (None, hparams.task.response_length)),
        scores=scalar_batch,
        non_score_reward=Schema(tf.float32, (None, hparams.task.response_length)),
        score_stats=score_fn.stat_schemas,
        train_stats=ppo_stat_schemas,
    )

    @utils.graph_function(
        **stat_data_schemas,
        kl_coef=Schema(tf.float32, ()))
    def record_step_stats(*, kl_coef, **data):
        ppo_summary_writer = utils.get_summary_writer(
            self.hparams.run.save_dir, subdir='ppo', comm=self.comm)

        kl = data['logprobs'] - data['ref_logprobs']
        mean_kl = tf.reduce_mean(tf.reduce_sum(kl, axis=1))
        mean_entropy = tf.reduce_mean(tf.reduce_sum(-data['logprobs'], axis=1))
        mean_non_score_reward = tf.reduce_mean(tf.reduce_sum(data['non_score_reward'], axis=1))
        stats = {
            'objective/kl': mean_kl,
            'objective/kl_coef': kl_coef,
            'objective/entropy': mean_entropy,
        }
        for k, v in data['train_stats'].items():
            stats[f'ppo/{k}'] = tf.reduce_mean(v, axis=0)
        for k, v in data['score_stats'].items():
            mean = tf.reduce_mean(v, axis=0)
            stats[f'objective/{k}'] = mean
            stats[f'objective/{k}_total'] = mean + mean_non_score_reward

        stats = utils.FlatStats.from_dict(stats).map_flat(
            partial(utils.mpi_allreduce_mean, comm=self.comm)).as_dict()

        # Add more statistics
        step = tf.train.get_global_step().read_value()
        stats['ppo/val/var_explained'] = 1 - stats['ppo/val/error'] / stats['ppo/returns/var']
        steps = step + 1
        stats.update({
            'elapsed/updates': steps,
            'elapsed/steps/serial': steps * hparams.task.response_length,
            'elapsed/steps/total': steps * hparams.ppo.batch_size * hparams.task.response_length,
            'elapsed/episodes': steps * hparams.ppo.batch_size,
        })

        # Time statistics
        total, delta = tf_times()
        stats.update({
            'elapsed/fps': tf.cast(hparams.ppo.batch_size * hparams.task.response_length / delta, tf.int32),
            'elapsed/time': total,
        })
        if ppo_summary_writer:
            record_op = utils.record_stats(
                stats=stats, summary_writer=ppo_summary_writer, step=step,
                log_interval=hparams.run.log_interval, name='ppo_stats', comm=self.comm)
        else:
            record_op = tf.no_op()
        return record_op, stats
    self.record_step_stats = record_step_stats
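    # Worked example (sketch) of compute_rewards defined above, assuming
    # kl_ctl.value == 0.2 and a single response of length 2:
    #   logprobs     = [[-1.0, -2.0]]
    #   ref_logprobs = [[-1.5, -1.5]]
    #   kl               = logprobs - ref_logprobs = [[ 0.5, -0.5]]
    #   non_score_reward = -0.2 * kl               = [[-0.1,  0.1]]
    #   with scores = [3.0], rewards = [[-0.1, 3.1]]
    # i.e. the KL penalty against the reference policy is applied per token,
    # and the score from score_fn is added only at the final response token.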
def question_schemas(self, *, query_length, response_length) -> Dict[str, Schema]:
    return dict(
        query=Schema(tf.int32, (query_length,)),
        sample=Schema(tf.int32, (response_length,)),
    )
def label_schemas(self):
    return dict(score=Schema(tf.float32, ()))
def question_schemas(self, *, query_length, response_length) -> Dict[str, Schema]:
    return dict(
        query=Schema(tf.int32, (query_length,)),
        **{f"sample{i}": Schema(tf.int32, (response_length,)) for i in range(self.num_responses)},
    )
def label_schemas(self):
    return dict(best=Schema(tf.int32, ()))
def __init__(self, *, reward_model, policy, query_sampler, hparams, comm):
    self.reward_model = reward_model
    self.policy = policy
    self.hparams = hparams
    self.num_ranks = comm.Get_size()
    self.rank = comm.Get_rank()
    self.comm = comm

    self.label_type = label_types.get(hparams.labels.type)
    self.question_schemas = self.label_type.question_schemas(
        query_length=hparams.task.query_length,
        response_length=hparams.task.response_length,
    )

    data_schemas = {
        **self.question_schemas,
        **self.label_type.label_schemas(),
    }

    with tf.device(None), tf.device('/cpu:0'):
        with tf.variable_scope('label_buffer', use_resource=True, initializer=tf.zeros_initializer):
            self.train_buffer = utils.SampleBuffer(capacity=hparams.labels.num_train, schemas=data_schemas)

    with tf.name_scope('train_reward'):
        summary_writer = utils.get_summary_writer(
            self.hparams.run.save_dir, subdir='reward_model', comm=comm)

        @utils.graph_function(indices=Schema(tf.int32, (None,)), lr=Schema(tf.float32, ()))
        def train_batch(indices, lr):
            with tf.name_scope('minibatch'):
                minibatch = self.train_buffer.read(indices)
                stats = self.label_type.loss(
                    reward_model=self.reward_model.get_rewards_op, labels=minibatch)
                train_op = utils.minimize(
                    loss=stats['loss'], lr=lr, params=self.reward_model.get_params(),
                    name='opt', comm=self.comm)
                with tf.control_dependencies([train_op]):
                    step_var = tf.get_variable(
                        name='train_step', dtype=tf.int64, shape=(), trainable=False, use_resource=True)
                    step = step_var.assign_add(1) - 1
                    stats = utils.FlatStats.from_dict(stats).map_flat(
                        partial(utils.mpi_allreduce_mean, comm=comm)).as_dict()
                    train_stat_op = utils.record_stats(
                        stats=stats, summary_writer=summary_writer, step=step,
                        log_interval=hparams.run.log_interval, comm=comm)
            return train_stat_op
        self.train_batch = train_batch

    if self.hparams.normalize_before or self.hparams.normalize_after:
        @utils.graph_function()
        def target_mean_std():
            """Returns the means and variances to target for each reward model"""
            # Should be the same on all ranks because the train_buf should be the same
            scales = self.label_type.target_scales(self.train_buffer.data())
            if scales is None:
                return tf.zeros([]), tf.ones([])
            else:
                mean, var = tf.nn.moments(scales, axes=[0])
                return mean, tf.sqrt(var)
        self.target_mean_std = target_mean_std

        def stats(query_responses):
            rewards = np.concatenate([
                self.reward_model.get_rewards(qs, rs) for qs, rs in query_responses
            ], axis=0)
            assert len(rewards.shape) == 1, f'{rewards.shape}'
            sums = np.asarray([rewards.sum(axis=0), np.square(rewards).sum(axis=0)])
            means, sqr_means = self.comm.allreduce(sums, op=MPI.SUM) / (self.num_ranks * rewards.shape[0])
            stds = np.sqrt(sqr_means - means ** 2)
            return means, stds
        self.stats = stats

        def log_stats_after_normalize(stats):
            if comm.Get_rank() != 0:
                return
            means, stds = stats
            print(f'after normalize: {means} +- {stds}')
        self.log_stats_after_normalize = log_stats_after_normalize

        def reset_reward_scales():
            self.reward_model.reset_reward_scale()
        self.reset_reward_scales = reset_reward_scales

        def set_reward_norms(mean, std, new_mean, new_std):
            print(f'targets: {new_mean} +- {new_std}')
            print(f'before normalize: {mean} +- {std}')
            assert np.isfinite((mean, std, new_mean, new_std)).all()
            self.reward_model.set_reward_norm(old_mean=mean, old_std=std, new_mean=new_mean, new_std=new_std)
        self.set_reward_norms = set_reward_norms

    if self.hparams.normalize_before or self.hparams.normalize_after:
        @utils.graph_function()
        def sample_policy_batch():
            queries = query_sampler('ref_queries')['tokens']
            responses = policy.respond_op(
                queries=queries, length=hparams.task.response_length)['responses']
            return queries, responses

        def sample_policy_responses(n_samples):
            n_batches = utils.ceil_div(n_samples, hparams.rollout_batch_size)
            return [sample_policy_batch() for _ in range(n_batches)]
        self.sample_policy_responses = sample_policy_responses

    @utils.graph_function(labels=utils.add_batch_dim(data_schemas))
    def add_to_buffer(labels):
        return self.train_buffer.add(**labels)
    self.add_to_buffer = add_to_buffer
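    # Normalization flow (sketch; the driver code is not shown here and the call
    # order below is an assumption based on the helpers defined above):
    #   query_responses = self.sample_policy_responses(n_samples)
    #   mean, std = self.stats(query_responses)            # current reward scale under the policy
    #   target_mean, target_std = self.target_mean_std()   # zeros/ones until labels provide target_scales
    #   self.set_reward_norms(mean, std, target_mean, target_std)
    #   self.log_stats_after_normalize(self.stats(query_responses))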