def _train(self, rank=0):
    # Optimizer
    self.optimizer = self.configure_optimizer()
    self.rank = rank
    self.master = rank == 0

    # Training begin
    self.callback_handler.fire_event(Events.TRAIN_BEGIN)
    self.train_loader = cycle_wrapper(self.train_loader)

    # Scheduler
    self.scheduler = self.configure_scheduler()

    for _ in range(self.global_iter_count, self.total_num_iterations):
        if self.local_iter_count == 0:
            self.callback_handler.fire_event(Events.EPOCH_BEGIN)

        # The dataloader state should be preserved via torch.get_rng_state.
        # This is not fully deterministic, but it at least preserves the
        # data-loading randomness when resuming.
        self.batch = next(self.train_loader)

        # The callback handler has access to trainer.batch
        self.callback_handler.fire_event(Events.BATCH_BEGIN)

        self.batch = move_to_device(self.batch, self.device)
        self.batch_results = self.train_iter(self.batch)

        # Update the model
        if (self.global_iter_count + 1) % self.config.training.gradient_accumulation_steps == 0:
            self.callback_handler.fire_event(Events.STEP_BEGIN)
            self.optimizer.step()
            self.scheduler.step()
            self.optimizer.zero_grad()
            self.callback_handler.fire_event(Events.STEP_END)

        self.callback_handler.fire_event(Events.BATCH_END)

        # Validation (only the master process validates)
        if self.master:
            if (self.global_iter_count + 1) % self.config.training.validation_iterations_interval == 0:
                if self.validation_loader:
                    self.callback_handler.fire_event(Events.VALIDATE_BEGIN)
                    self.model.eval()
                    self.validate_metrics = self.validate()
                    self.callback_handler.fire_event(Events.VALIDATE_END)
                    self.model.train()

        # Epoch bookkeeping
        if not self.no_epoch_training and (self.local_iter_count + 1) % self.num_training_batches == 0:
            self.callback_handler.fire_event(Events.EPOCH_END)
            self.epochs_trained += 1
            self.local_iter_count = 0
        else:
            self.local_iter_count += 1

        # Post
        self.global_iter_count += 1

    self.callback_handler.fire_event(Events.TRAIN_END)
    return {}
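# `cycle_wrapper` is used above but not defined in this file. A minimal sketch
# of the assumed behavior (re-iterate the dataloader forever so the step-based
# loop above never exhausts it); the real implementation may differ.
def cycle_wrapper(dataloader):
    """Yield batches from `dataloader` indefinitely, restarting after each pass."""
    while True:
        for batch in dataloader:
            yield batch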
def train_epoch(self):
    self.optimizer = self.optimizers[0]
    self.scheduler = self.schedulers[0]
    self.local_step_count = 0

    iter_train_dataloader = iter(self.train_dataloader)

    while self.global_step_count < self.total_num_update_steps:
        self.callback_handler.fire_event(Events.BATCH_BEGIN)

        # Collect samples until the PPO buffer is full
        buffer_count = 0
        while buffer_count < self.ppo_buffer_size:
            try:
                batch = next(iter_train_dataloader)
            except StopIteration:
                iter_train_dataloader = iter(self.train_dataloader)
                batch = next(iter_train_dataloader)
            self.collect_samples(batch)
            buffer_count += len(batch)

        # Train on all samples in the buffer
        for mini_batch in self.replay_buffer.iterate_sample(self.ppo_mini_batch_size):
            # (state, action, action_log_prob, reward, normalized_reward)
            states, actions, action_log_probs, rewards, normalized_rewards = zip(*mini_batch)

            for i in range(len(states)):
                states[i]["target_token_ids"] = actions[i]

            ppo_batch = self.collate_fn(states)
            ppo_batch["normalized_rewards"] = torch.LongTensor(normalized_rewards)
            ppo_batch["old_log_probs"] = torch.FloatTensor(action_log_probs)
            ppo_batch = move_to_device(ppo_batch, self.device)

            self.tmp_vars["log_dict"] = self.train_step(ppo_batch)

            self.callback_handler.fire_event(Events.STEP_BEGIN)
            self.optimizer.step()
            self.scheduler.step()
            self.optimizer.zero_grad()
            self.callback_handler.fire_event(Events.STEP_END)

        self.callback_handler.fire_event(Events.BATCH_END)
        self.replay_buffer.clear()

        self.global_step_count += 1
        self.local_step_count += 1
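# Sketch of the replay-buffer interface assumed by `train_epoch` above: it
# stores (state, action, action_log_prob, reward, normalized_reward) tuples
# and yields shuffled mini-batches. Hypothetical implementation for
# illustration; the real buffer (and its reward normalization) is not shown.
import random

class ReplayBuffer:
    def __init__(self):
        self.storage = []

    def update_batch(self, states, actions, action_log_probs, rewards,
                     normalize_reward=True):
        # optionally standardize rewards before storing them
        if normalize_reward and len(rewards) > 1:
            mean = sum(rewards) / len(rewards)
            std = (sum((r - mean) ** 2 for r in rewards) / len(rewards)) ** 0.5
            normalized = [(r - mean) / (std + 1e-8) for r in rewards]
        else:
            normalized = list(rewards)
        self.storage.extend(
            zip(states, actions, action_log_probs, rewards, normalized))

    def iterate_sample(self, mini_batch_size):
        # yield the whole buffer once, in shuffled mini-batches
        random.shuffle(self.storage)
        for i in range(0, len(self.storage), mini_batch_size):
            yield self.storage[i:i + mini_batch_size]

    def clear(self):
        self.storage.clear()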
def validate(self):
    self.model.eval()

    for batch in self.validation_loader:
        # send to cuda device
        batch = move_to_device(batch, self.device)

        # No gradient is needed for validation
        with torch.no_grad():
            if self.config.training.num_gpus_per_node > 1:
                self.model.module.predict(batch)
            else:
                self.model.predict(batch)

    # get metrics
    if self.config.training.num_gpus_per_node > 1:
        metrics = self.model.module.get_metrics(reset=True)
    else:
        metrics = self.model.get_metrics(reset=True)

    return metrics
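# `move_to_device` is assumed to recursively move tensors (and modules) inside
# nested containers onto the target device. A minimal sketch of that contract:
import torch

def move_to_device(obj, device):
    if torch.is_tensor(obj) or isinstance(obj, torch.nn.Module):
        return obj.to(device)
    if isinstance(obj, dict):
        return {k: move_to_device(v, device) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return type(obj)(move_to_device(v, device) for v in obj)
    # non-tensor leaves (ints, strings, ...) are returned unchanged
    return obj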
def collect_samples(self, batch):
    """Generate samples, collect rewards, and update the replay buffer"""
    num_human_demos = int(len(batch) * self.mix_human_demo_ratio)
    # num_generations = int(len(batch) * (1 - self.mix_human_demo_ratio))
    self.mix_human_demo_ratio *= self.mix_human_demo_ratio_decay

    # Update buffer with human demonstrations
    for i in range(0, num_human_demos, self.sample_batch_size):
        human_demos_batch = batch[i:i + self.sample_batch_size]
        human_log_probs = [
            torch.zeros((item["source_token_ids"].shape[0]))
            for item in human_demos_batch
        ]
        human_tokens = [item["source_token_ids"] for item in human_demos_batch]

        if self.config.text_ppo.constant_human_demo_reward:
            rewards = np.ones((len(human_demos_batch))) * 2.0
            self.replay_buffer.update_batch(
                states=human_demos_batch,
                actions=human_tokens,
                action_log_probs=human_log_probs,
                rewards=rewards,
                normalize_reward=False)
        else:
            results = {}
            results["tokens"] = [
                item["target_token_ids"] for item in human_demos_batch
            ]
            rewards = self.reward_func(human_demos_batch, results, is_human_demo=True)
            self.replay_buffer.update_batch(
                states=human_demos_batch,
                actions=human_tokens,
                action_log_probs=human_log_probs,
                rewards=rewards)

    # Update buffer with model generations
    for i in range(num_human_demos, len(batch), self.sample_batch_size):
        sample_batch = batch[i:i + self.sample_batch_size]
        sample_batch_collated = self.collator.sample_collate(sample_batch)
        sample_batch_collated = move_to_device(sample_batch_collated, self.device)
        results = self.decoder_generate(sample_batch_collated)

        # TODO: Consider num_return_sequences
        results["tokens"] = [item[0] for item in results["tokens"]]
        results["log_probs"] = [item[0] for item in results["log_probs"]]

        rewards = self.reward_func(sample_batch, results, is_human_demo=False)
        self.replay_buffer.update_batch(
            states=sample_batch,
            actions=results["tokens"],
            action_log_probs=results["log_probs"],
            rewards=rewards)
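# Hypothetical illustration of the `decoder_generate` output consumed above:
# per input example, "tokens" and "log_probs" hold one entry per returned
# sequence, and `item[0]` keeps only the first candidate.
results = {
    "tokens": [[[101, 7, 8], [101, 9]]],  # 1 example, 2 returned sequences
    "log_probs": [[-3.2, -4.1]],
}
results["tokens"] = [item[0] for item in results["tokens"]]        # [[101, 7, 8]]
results["log_probs"] = [item[0] for item in results["log_probs"]]  # [-3.2]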
def train_epoch(self):
    self.optimizer = self.optimizers[0]
    self.scheduler = self.schedulers[0]
    self.local_step_count = 0

    for batch in self.train_dataloader:
        self.callback_handler.fire_event(Events.BATCH_BEGIN)

        batch = move_to_device(batch, self.device)
        self.tmp_vars["log_dict"] = self.train_step(batch)

        # Update the model
        if (self.global_step_count + 1) % self.gradient_accumulation_steps == 0:
            self.step_update()

        self.callback_handler.fire_event(Events.BATCH_END)

        # Only rank 0 runs the validation dataset
        if self.rank == 0:
            if self.global_step_count > self.validation_after_num_steps and \
                    (self.global_step_count + 1) % self.validation_steps_interval == 0:
                if self.validation_dataloader is not None:
                    self.model.eval()
                    self.model.is_training = False

                    self.callback_handler.fire_event(Events.VALIDATE_BEGIN)
                    self.tmp_vars["validate_metrics"] = self.validate()
                    self.callback_handler.fire_event(Events.VALIDATE_END)

                    self.model.train()
                    self.model.is_training = True

        if self.config.training.num_gpus_per_node > 1:
            torch.distributed.barrier()

        if self.global_step_count >= self.total_num_steps:
            break

        self.global_step_count += 1
        self.local_step_count += 1
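# `step_update` is called above but not shown; a minimal sketch of the assumed
# helper, mirroring the explicit optimizer/scheduler updates used elsewhere in
# this file (see the PPO `train_epoch`). The real helper may do more, e.g.
# gradient clipping or fp16 loss-scale handling.
def step_update(self):
    self.callback_handler.fire_event(Events.STEP_BEGIN)
    self.optimizer.step()
    self.scheduler.step()
    self.optimizer.zero_grad()
    self.callback_handler.fire_event(Events.STEP_END)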
def validate(self):
    # Validation
    self.model.eval()

    # No gradient is needed for validation
    with torch.no_grad():
        for batch in self.validation_dataloader:
            # send to cuda device
            batch = move_to_device(batch, self.device)
            if self.distributed_training:
                self.model.module.predict(batch)
            else:
                self.model.predict(batch)

    # get metrics
    if self.distributed_training:
        metrics = self.model.module.get_metrics(reset=True)
    else:
        metrics = self.model.get_metrics(reset=True)

    return metrics
def train_epoch(self):
    self.D_optimizer = self.optimizers[0]
    self.G_optimizer = self.optimizers[1]
    self.D_scheduler = self.schedulers[0]
    self.G_scheduler = self.schedulers[1]

    prop_decay = self.mix_human_demo_ratio / self.total_num_update_steps

    # total counts
    self.local_step_count = 0
    self.generator_step_count = 0
    self.discriminator_step_count = 0

    iter_train_dataloader = iter(self.train_dataloader)

    while self.global_step_count < self.total_num_update_steps:
        self.callback_handler.fire_event(Events.BATCH_BEGIN)

        # Collect samples until the PPO buffer is full
        buffer_count = 0
        while buffer_count < self.ppo_buffer_size:
            try:
                batch = next(iter_train_dataloader)
            except StopIteration:
                iter_train_dataloader = iter(self.train_dataloader)
                self.epochs_trained += 1
                logger.info(f"Epoch {self.epochs_trained} has finished!")
                batch = next(iter_train_dataloader)
            self.collect_samples(batch)
            buffer_count += len(batch)

        if not self.done_discriminator_pretrain:
            # Discriminator warmup
            for mini_batch in self.replay_buffer.iterate_sample(self.ppo_mini_batch_size):
                # (state, action, action_log_prob, reward, normalized_reward)
                states, actions, action_log_probs, rewards, normalized_rewards = zip(*mini_batch)

                self.tmp_vars["log_dict"] = self.train_discriminator_step(states, actions)

                if (self.discriminator_step_count + 1) % self.gradient_accumulation_steps == 0:
                    # self.callback_handler.fire_event(Events.STEP_BEGIN)
                    self.D_optimizer.step()
                    self.D_scheduler.step()
                    self.D_optimizer.zero_grad()
                    # self.callback_handler.fire_event(Events.STEP_END)

                self.discriminator_step_count += 1

                if self.discriminator_step_count >= self.discriminator_pretrain_steps:
                    self.done_discriminator_pretrain = True
                    break
        else:
            # Generator training
            for _ in range(self.ppo_epoch):
                for mini_batch in self.replay_buffer.iterate_sample(self.ppo_mini_batch_size):
                    # (state, action, action_log_prob, reward, normalized_reward)
                    states, actions, action_log_probs, rewards, normalized_rewards = zip(*mini_batch)

                    ppo_batch = self.collate_fn(states)
                    ppo_batch["target_token_ids"] = pad_sequence(
                        actions, batch_first=True, padding_value=self.pad_token_id)
                    ppo_batch["normalized_rewards"] = torch.LongTensor(normalized_rewards)
                    ppo_batch["old_log_probs"] = torch.FloatTensor(action_log_probs)
                    ppo_batch = move_to_device(ppo_batch, self.device)

                    self.tmp_vars["log_dict"] = self.train_generator_step(ppo_batch)

                    if (self.generator_step_count + 1) % self.gradient_accumulation_steps == 0:
                        # self.callback_handler.fire_event(Events.STEP_BEGIN)
                        self.G_optimizer.step()
                        self.G_scheduler.step()
                        self.G_optimizer.zero_grad()
                        # self.callback_handler.fire_event(Events.STEP_END)

                    self.generator_step_count += 1

            # Discriminator training
            for mini_batch in self.replay_buffer.iterate_sample(self.ppo_mini_batch_size):
                states, actions, action_log_probs, rewards, normalized_rewards = zip(*mini_batch)

                log_dict = self.train_discriminator_step(states, actions)
                self.tmp_vars["log_dict"].update(log_dict)

                if (self.discriminator_step_count + 1) % self.gradient_accumulation_steps == 0:
                    # self.callback_handler.fire_event(Events.STEP_BEGIN)
                    self.D_optimizer.step()
                    self.D_scheduler.step()
                    self.D_optimizer.zero_grad()
                    # self.callback_handler.fire_event(Events.STEP_END)

                self.discriminator_step_count += 1

        # Decay the human demonstration mixing ratio
        self.mix_human_demo_ratio -= prop_decay
        self.tmp_vars["log_dict"]["mix_human_demo_ratio"] = self.mix_human_demo_ratio

        self.callback_handler.fire_event(Events.BATCH_END)
        self.replay_buffer.clear()

        # Only rank 0 runs the validation dataset
        if self.rank == 0:
            if (self.global_step_count + 1) % self.validation_steps_interval == 0:
                if self.validation_dataloader is not None:
                    self.model.eval()

                    self.callback_handler.fire_event(Events.VALIDATE_BEGIN)
                    self.tmp_vars["validate_metrics"] = self.validate()
                    self.callback_handler.fire_event(Events.VALIDATE_END)

                    self.model.train()

        self.global_step_count += 1
        self.local_step_count += 1
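# Quick illustration of torch.nn.utils.rnn.pad_sequence as used above:
# variable-length action sequences are padded into a single
# (batch, max_len) tensor, with the padding value filling the tail.
import torch
from torch.nn.utils.rnn import pad_sequence

actions = [torch.tensor([5, 6, 7]), torch.tensor([8, 9])]
padded = pad_sequence(actions, batch_first=True, padding_value=0)
# padded -> tensor([[5, 6, 7],
#                   [8, 9, 0]])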
def collect_samples(self, batch):
    """Generate samples, collect rewards, and update the replay buffer"""
    num_human_demos = int(len(batch) * self.mix_human_demo_ratio)
    # num_generations = int(len(batch) * (1 - self.mix_human_demo_ratio))

    # Update buffer with human demonstrations
    actual_sample_size = min(num_human_demos, self.sample_batch_size)
    if actual_sample_size > 0:
        for i in range(0, num_human_demos, actual_sample_size):
            human_demos_batch = batch[i:i + actual_sample_size]

            # collect human demo log probs
            human_demos_batch_collated = self.collate_fn(human_demos_batch)
            human_demos_batch_collated = move_to_device(
                human_demos_batch_collated, self.device)
            log_probs = self.model.generator.compute_log_probs(
                human_demos_batch_collated)["log_probs"]
            human_log_probs = log_probs.tolist()

            human_tokens = [
                item["target_token_ids"] for item in human_demos_batch
            ]

            if self.constant_human_demo_reward:
                rewards = np.ones((len(human_demos_batch))) * 2.0
                self.replay_buffer.update_batch(
                    states=human_demos_batch,
                    actions=human_tokens,
                    action_log_probs=human_log_probs,
                    rewards=rewards,
                    normalize_reward=False)
            else:
                results = {}
                results["tokens"] = human_tokens
                rewards = self.reward_func.get_reward(human_demos_batch, results)
                self.replay_buffer.update_batch(
                    states=human_demos_batch,
                    actions=human_tokens,
                    action_log_probs=human_log_probs,
                    rewards=rewards)

    # Update buffer with model generations
    actual_sample_size = min(len(batch) - num_human_demos, self.sample_batch_size)
    # guard against a zero range step when the whole batch was human demos
    if actual_sample_size > 0:
        for i in range(num_human_demos, len(batch), actual_sample_size):
            sample_batch = batch[i:i + actual_sample_size]
            sample_batch_collated = self.collate_fn(sample_batch)
            sample_batch_collated = move_to_device(sample_batch_collated, self.device)
            results = self.decoder_generate(sample_batch_collated)

            # TODO: Consider num_return_sequences
            results["tokens"] = [item[0] for item in results["tokens"]]

            # recompute the log probs for better precision
            if self.recompute_log_probs:
                temp_target_token_ids = sample_batch_collated["target_token_ids"]
                sample_batch_collated["target_token_ids"] = pad_sequence(
                    results["tokens"],
                    batch_first=True,
                    padding_value=self.pad_token_id).to(self.device)
                log_probs = self.model.generator.compute_log_probs(
                    sample_batch_collated)["log_probs"]
                results["log_probs"] = log_probs.tolist()
                # switch back to the original target_token_ids
                sample_batch_collated["target_token_ids"] = temp_target_token_ids
            else:
                results["log_probs"] = [item[0] for item in results["log_probs"]]

            rewards = self.reward_func.get_reward(sample_batch, results)
            self.replay_buffer.update_batch(
                states=sample_batch,
                actions=results["tokens"],
                action_log_probs=results["log_probs"],
                rewards=rewards)
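# Sketch of the reward interface assumed above: `get_reward` returns one
# scalar per example, scoring `results["tokens"]` against the source batch.
# Purely hypothetical placeholder; the real reward function (e.g. a learned
# discriminator) is defined elsewhere and not shown in this file.
class ConstantRewardFunc:
    """Toy reward: every sequence gets the same score."""

    def get_reward(self, batch, results):
        return [1.0 for _ in results["tokens"]]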
def __init__(self,
             config: DictConfig,
             model: FlyModel,
             train_dataloader_fn: Callable,
             valid_dataloader_fn: Callable = None,
             test_dataloader_fn: Callable = None):
    """
    Args:
        config: FlyConfig dictionary
        model: must be a FlyModel
        train_dataloader_fn/valid_dataloader_fn/test_dataloader_fn:
            callables which take the config and return dataloaders
    """
    assert isinstance(model, FlyModel)

    self.config = config
    self.rank, self.local_rank = get_rank()

    # Distributed
    if self.config.training.num_gpus_per_node > 1:
        # Init distributed
        # TODO: multi-node multi-gpu training
        torch.distributed.init_process_group(
            backend="nccl",
            rank=self.rank,
            world_size=self.config.training.num_gpus_per_node * 1)

    # configure distributed training
    self.model = model
    self.train_dataloader = train_dataloader_fn(config)
    self.validation_dataloader: Iterable = valid_dataloader_fn(
        config) if valid_dataloader_fn else None
    self.test_dataloader = test_dataloader_fn(
        config) if test_dataloader_fn else None

    self.callback_handler = CallbackHandler(
        config,
        trainer=self,
        callbacks=[],
        verbose=config.training.logging.level == "DEBUG")

    # constants
    self.gradient_accumulation_steps = config.training.optimization.gradient_accumulation_steps
    self.validation_steps_interval = config.training.validation.steps_interval
    self.fp16 = config.training.optimization.fp16
    self.fp16_opt_level = config.training.optimization.fp16_opt_level
    self.distributed_training = False

    self.total_num_update_steps = int(config.training.total_num.update_steps)
    self.total_num_steps = self.total_num_update_steps * int(self.gradient_accumulation_steps)
    self.total_num_epochs = int(self.config.training.total_num.epochs)

    # Train in epochs or steps
    if self.total_num_epochs > 0:
        self.training_in_epoch = True
    else:
        if self.total_num_update_steps < 0:
            raise NotImplementedError(
                "config.training.total_num.update_steps must be larger than 0")
        self.training_in_epoch = False
        self.total_num_epochs = 1

    # Number of training batches
    if self.training_in_epoch:
        try:
            self.epoch_num_training_steps = len(self.train_dataloader)
            self.total_num_training_steps = self.epoch_num_training_steps * self.total_num_epochs
            self.total_num_update_steps = self.total_num_training_steps // self.gradient_accumulation_steps
        except TypeError:
            # cannot infer the number of steps per epoch because the
            # train dataloader has no length
            logger.error("Cannot get the length of the train dataloader!")
            raise NotImplementedError(
                "Please specify `total_num_epochs` or `total_num_update_steps`!")
    else:
        self.epoch_num_training_steps = self.total_num_update_steps

    # Validation steps interval
    self.validation_after_num_steps = config.training.validation.after_num_steps
    if self.validation_steps_interval < 0:
        self.validation_steps_interval = self.epoch_num_training_steps - 1

    # local variables
    self.global_step_count = 0
    self.epochs_trained = 0
    self.local_step_count = 0

    # set cuda device
    if config.training.num_gpus_per_node > 1:
        torch.cuda.set_device(self.rank)
        self.device = torch.device("cuda", self.local_rank)
    elif config.training.num_gpus_per_node == 1:
        self.device = torch.device("cuda")
    else:
        self.device = torch.device("cpu")

    # Configure optimizers
    self.optimizers, self.schedulers = self.model.configure_optimizers(
        self.total_num_update_steps)
    self.optimizers, self.schedulers = self.configure_optimizers()

    # Model is sent to GPU or CPU
    self.model = move_to_device(self.model, self.device)

    # Mixed-Precision
    if self.fp16 and self.config.training.num_gpus_per_node > 0:
        self.configure_fp16()

    # Distributed Training
    if self.config.training.num_gpus_per_node > 1:
        self.configure_ddp()

    self.configure_callbacks()

    self.log_keys = set()
    self.tmp_vars = {}
    self.callback_handler.fire_event(Events.INITIALIZE)

    # make sure the model has access to trainer info
    self.model.set_trainer(self)
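# Hypothetical end-to-end usage sketch. `load_config`, `build_train_loader`,
# and `MyFlyModel` are illustrative stand-ins (not defined in this file); the
# real entry point, config loading, and training call may differ.
#
#     config = load_config("config/train.yaml")   # yields a DictConfig
#     model = MyFlyModel(config)                  # any FlyModel subclass
#     trainer = Trainer(config,
#                       model,
#                       train_dataloader_fn=build_train_loader)
#     trainer.train()                             # or trainer.train_epoch()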