def train_epoch(self):
    self.optimizer = self.optimizers[0]
    self.scheduler = self.schedulers[0]
    self.local_step_count = 0

    if self.train_dataloader is None:
        return

    for batch in self.train_dataloader:
        self.callback_handler.fire_event(Events.BATCH_BEGIN)
        batch = move_to_device(batch, self.device)
        output = self.backward_batch(batch)

        # Update the model once every `gradient_accumulation_batches` batches
        if (self.global_batch_count + 1) % self.gradient_accumulation_batches == 0:
            # Update the model with optimizer
            self.step_update(self.model, self.optimizer, self.scheduler)
            self.global_step_count += 1
            self.local_step_count += 1

        self.callback_handler.fire_event(Events.BATCH_END)

        if self.global_step_count >= self.total_num_update_steps:
            break

        self.global_batch_count += 1
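
# `step_update` is referenced above but not shown in this file. A minimal
# sketch of what it presumably does, assuming standard gradient clipping with
# `self.max_gradient_norm` and that fp16 loss scaling is handled elsewhere:
def step_update(self, model, optimizer, scheduler):
    # Clip gradients before stepping (assumed behaviour)
    torch.nn.utils.clip_grad_norm_(model.parameters(), self.max_gradient_norm)
    optimizer.step()
    scheduler.step()
    optimizer.zero_grad()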
def test_loop(self, dataloader):
    self.eval()
    self.reset_evaluation_metrics()

    # No gradient is needed for validation
    with torch.no_grad():
        pbar = tqdm.tqdm(dataloader)
        pbar.mininterval = 2.0
        for batch_idx, batch in enumerate(pbar):
            # send to cuda device
            batch = move_to_device(batch, self.device)
            self.test_step(batch, batch_idx)
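
# `test_step` is expected to run the model on one batch and accumulate
# evaluation metrics. A minimal sketch, assuming the model exposes `predict`
# (as in the trainer loops below); the `update_evaluation_metrics` helper is
# a hypothetical name:
def test_step(self, batch, batch_idx):
    # Forward pass; the caller already wraps this in torch.no_grad()
    predictions = self.predict(batch)
    # Accumulate metrics for later aggregation (hypothetical helper)
    self.update_evaluation_metrics(batch, predictions)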
def train_epoch(self):
    self.optimizer = self.optimizers[0]
    self.scheduler = self.schedulers[0]
    self.local_step_count = 0

    iter_train_dataloader = iter(self.train_dataloader)

    while self.global_step_count < self.total_num_update_steps:
        self.callback_handler.fire_event(Events.BATCH_BEGIN)

        # Collect samples until the PPO buffer is full
        buffer_count = 0
        while buffer_count < self.ppo_buffer_size:
            try:
                batch = next(iter_train_dataloader)
            except StopIteration:
                # Restart the dataloader once it is exhausted
                iter_train_dataloader = iter(self.train_dataloader)
                batch = next(iter_train_dataloader)
            self.collect_samples(batch)
            buffer_count += len(batch)

        # Train on all samples in the buffer
        for mini_batch in self.replay_buffer.iterate_sample(self.ppo_mini_batch_size):
            # (state, action, action_log_prob, reward, normalized_reward)
            states, actions, action_log_probs, rewards, normalized_rewards = zip(*mini_batch)

            for i in range(len(states)):
                states[i]["target_token_ids"] = actions[i]

            ppo_batch = self.collate_fn(states)
            # normalized rewards are continuous, so keep them as floats
            ppo_batch["normalized_rewards"] = torch.FloatTensor(normalized_rewards)
            ppo_batch["old_log_probs"] = torch.FloatTensor(action_log_probs)
            ppo_batch = move_to_device(ppo_batch, self.device)

            self.tmp_vars["log_dict"] = self.train_step(ppo_batch)

            self.callback_handler.fire_event(Events.STEP_BEGIN)
            self.optimizer.step()
            self.scheduler.step()
            self.optimizer.zero_grad()
            self.callback_handler.fire_event(Events.STEP_END)

        self.callback_handler.fire_event(Events.BATCH_END)
        self.replay_buffer.clear()
        self.global_step_count += 1
        self.local_step_count += 1
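
# The replay buffer interface used above (`update_batch`, `iterate_sample`,
# `clear`) is defined elsewhere. A minimal sketch of one way it could look;
# the tuple layout matches what train_epoch unpacks, but the reward
# normalization scheme here is an assumption:
import random
import numpy as np

class SimpleReplayBuffer:
    def __init__(self):
        self.storage = []

    def update_batch(self, states, actions, action_log_probs, rewards, normalize_reward=True):
        # Optionally z-score the rewards over the incoming batch (assumed scheme)
        if normalize_reward:
            mean, std = float(np.mean(rewards)), float(np.std(rewards))
            normalized = [(r - mean) / (std + 1e-8) for r in rewards]
        else:
            normalized = list(rewards)
        for item in zip(states, actions, action_log_probs, rewards, normalized):
            self.storage.append(item)

    def iterate_sample(self, mini_batch_size):
        # Yield shuffled mini-batches of stored transitions
        random.shuffle(self.storage)
        for i in range(0, len(self.storage), mini_batch_size):
            yield self.storage[i:i + mini_batch_size]

    def clear(self):
        self.storage = []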
def __init__(self, config: DictConfig, model: FlyModel, name: str = "task1", *args, **kwargs):
    """
    Args:
        config: FlyConfig dictionary
        model: must be a FlyModel
        name: name of this training task
    """
    logger.info("TrainerLoop is initializing!")

    if not isinstance(model, FlyModel):
        logger.warning("model is not defined as FlyModel")

    self.config = config
    self.model = model
    self.name = name

    # class properties
    self.rank = None
    self.local_rank = None
    self.node_rank = None
    self.world_size = None
    self.distributed_training = None
    self.device = None

    self.fp16 = config.fp16
    self.gradient_accumulation_batches = config.gradient_accumulation_batches

    self.callback_handler = None
    self.optimizers = []
    self.schedulers = []

    self.init_distributed_environment()

    # Model is sent to GPU or CPU
    self.init_device()
    # self.optimizers, self.schedulers = self.configure_optimizers()
    self.model = move_to_device(self.model, self.device)
    self.model.device = self.device
    self.init_fp16()

    if self.distributed_training:
        self.init_distributed_model(self.model)

    # make sure the model has access to trainer info
    self.model.set_trainer(self)

    self.callback_handler = CallbackHandler(
        config, trainer=self, callbacks=[], verbose=config.logging.level == "DEBUG"
    )

    # Configure all callbacks
    self.configure_callbacks()
    self.callback_handler.fire_event(Events.INITIALIZE)
def test(self):
    # Start Testing
    self.model.eval()
    self.model.reset_evaluation_metrics()
    self.callback_handler.fire_event(Events.TEST_BEGIN)

    # No gradient is needed for test
    with torch.no_grad():
        pbar = tqdm.tqdm(self.test_dataloader)
        pbar.mininterval = 2.0
        for batch in pbar:
            # send to cuda device
            batch = move_to_device(batch, self.device)
            self.model.predict(batch)

    self.callback_handler.fire_event(Events.TEST_END)
def validate(self):
    # Start Validation
    self.model.eval()
    self.model.reset_evaluation_metrics()
    self.callback_handler.fire_event(Events.VALIDATE_BEGIN)

    # No gradient is needed for validation
    with torch.no_grad():
        pbar = tqdm.tqdm(self.validation_dataloader)
        pbar.mininterval = 2.0
        for batch in pbar:
            # send to cuda device
            batch = move_to_device(batch, self.device)
            self.model.predict(batch)

    self.callback_handler.fire_event(Events.VALIDATE_END)
def train(
    self,
    config,
    train_dataloader,
    validation_dataloader=None,
    test_dataloader=None,
    configure_optimizers=True,
    stage_name: str = "Stage1",
    *args,
    **kwargs,
):
    self.config = config
    self.stage_name = stage_name

    # Model is sent to GPU or CPU
    self.init_device(config)
    # self.optimizers, self.schedulers = self.configure_optimizers()
    self.gradient_accumulation_batches = config.gradient_accumulation_batches
    self.max_gradient_norm = config.optimization.max_gradient_norm
    self.fp16 = config.fp16

    self.model = move_to_device(self.model, self.device)
    self.model.device = self.device
    self.init_fp16(config)

    if self.distributed_training:
        self.init_distributed_model(self.model)

    self.total_num_update_steps = 0
    self.total_num_batches = 0
    self.total_num_epochs = 0
    self.epoch_num_batches = 0
    self.global_batch_count = 0
    self.global_step_count = 0
    self.epochs_trained = 0
    self.local_step_count = 0

    self.train_dataloader = train_dataloader
    self.validation_dataloader = validation_dataloader
    self.test_dataloader = test_dataloader

    self.init_training_constants(config)

    if configure_optimizers or len(self.optimizers) == 0:
        self.configure_optimizers(config, self.total_num_update_steps)

    # Training begins
    self.callback_handler.fire_event(Events.TRAIN_BEGIN)

    while True:
        self.callback_handler.fire_event(Events.EPOCH_BEGIN)
        self.train_epoch()
        self.callback_handler.fire_event(Events.EPOCH_END)
        self.epochs_trained += 1

        if self.training_in_epoch:
            if self.epochs_trained >= self.total_num_epochs:
                break
        else:
            if self.global_step_count < self.total_num_update_steps:
                continue
            else:
                break

    # Training ends
    self.callback_handler.fire_event(Events.TRAIN_END)
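
# `init_training_constants` is referenced above but not shown in this file.
# A sketch of how the totals could be derived; the config keys
# (`total_num_epochs`, `total_num_update_steps`) and the exact meaning of
# `training_in_epoch` are assumptions, not the repository's actual names:
def init_training_constants(self, config):
    self.epoch_num_batches = len(self.train_dataloader) if self.train_dataloader is not None else 0

    # Train either for a fixed number of epochs or for a fixed number of update steps
    self.training_in_epoch = config.get("total_num_epochs", None) is not None
    if self.training_in_epoch:
        self.total_num_epochs = config.total_num_epochs
        self.total_num_batches = self.epoch_num_batches * self.total_num_epochs
        self.total_num_update_steps = self.total_num_batches // self.gradient_accumulation_batches
    else:
        self.total_num_update_steps = config.total_num_update_steps
        self.total_num_batches = self.total_num_update_steps * self.gradient_accumulation_batches
        # ceiling division to cover the last partial epoch
        denom = max(self.epoch_num_batches, 1)
        self.total_num_epochs = (self.total_num_batches + denom - 1) // denom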
def collect_samples(self, batch):
    """Generate samples, collect rewards, and update the replay buffer"""
    num_human_demos = int(len(batch) * self.mix_human_demo_ratio)
    # num_generations = int(len(batch) * (1 - self.mix_human_demo_ratio))
    self.mix_human_demo_ratio *= self.mix_human_demo_ratio_decay

    # keep the step positive so `range` does not fail when there are no human demos
    actual_sample_size = max(min(num_human_demos, self.sample_batch_size), 1)

    # Update Buffer for Human Demos
    for i in range(0, num_human_demos, actual_sample_size):
        human_demos_batch = batch[i:i + actual_sample_size]

        # collect human demos log probs
        human_demos_batch_collated = self.collate_fn(human_demos_batch)
        human_demos_batch_collated = move_to_device(human_demos_batch_collated, self.device)
        log_probs = self.model.compute_log_probs(human_demos_batch_collated)["log_probs"]
        human_log_probs = log_probs.tolist()
        human_tokens = [item["target_token_ids"] for item in human_demos_batch]

        if self.config.text_ppo.constant_human_demo_reward:
            rewards = np.ones((len(human_demos_batch))) * 2.0
            self.replay_buffer.update_batch(
                states=human_demos_batch,
                actions=human_tokens,
                action_log_probs=human_log_probs,
                rewards=rewards,
                normalize_reward=False
            )
        else:
            results = {}
            results["tokens"] = human_tokens
            rewards = self.reward_func(human_demos_batch_collated, results)
            self.replay_buffer.update_batch(
                states=human_demos_batch,
                actions=human_tokens,
                action_log_probs=human_log_probs,
                rewards=rewards
            )

    # Update Buffer for Generations
    actual_sample_size = max(min(len(batch) - num_human_demos, self.sample_batch_size), 1)

    for i in range(num_human_demos, len(batch), actual_sample_size):
        sample_batch = batch[i:i + actual_sample_size]
        sample_batch_collated = self.collate_fn(sample_batch)
        sample_batch_collated = move_to_device(sample_batch_collated, self.device)
        results = self.decoder_generate(sample_batch_collated)

        # TODO: Consider num_return_sequences
        results["tokens"] = [item[0] for item in results["tokens"]]

        if self.recompute_log_probs:
            # Recompute log probs of the generated tokens under the current model
            sample_batch_collated["target_token_ids"] = pad_sequence(
                results["tokens"], batch_first=True,
                padding_value=self.pad_token_id
            ).to(self.device)
            log_probs = self.model.compute_log_probs(sample_batch_collated)["log_probs"]
            results["log_probs"] = log_probs.tolist()
        else:
            results["log_probs"] = [item[0] for item in results["log_probs"]]

        rewards = self.reward_func(sample_batch_collated, results)
        self.replay_buffer.update_batch(
            states=sample_batch,
            actions=results["tokens"],
            action_log_probs=results["log_probs"],
            rewards=rewards
        )
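
# `self.reward_func` is supplied elsewhere; above it is called as
# reward_func(collated_batch, results) and returns one reward per sample.
# A toy sketch of a compatible function, assuming a simple token-overlap
# reward against `target_token_ids` (purely illustrative, not the
# repository's reward model):
import numpy as np

def toy_reward_func(collated_batch, results):
    rewards = []
    for generated, reference in zip(results["tokens"], collated_batch["target_token_ids"]):
        generated_ids = generated.tolist() if hasattr(generated, "tolist") else list(generated)
        reference_ids = [t for t in reference.tolist() if t != 0]  # assumes 0 is the pad id
        # overlap-based reward in [0, 1]
        overlap = len(set(generated_ids) & set(reference_ids))
        rewards.append(overlap / max(len(set(reference_ids)), 1))
    return np.array(rewards)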
def __init__(
    self,
    config: DictConfig,
    model: FlyModel,
    train_dataloader_fn: Callable,
    valid_dataloader_fn: Callable = None,
    test_dataloader_fn: Callable = None
):
    """
    Args:
        config: FlyConfig dictionary
        model: must be a FlyModel
        train_dataloader_fn: a Callable which returns the training dataloader
        valid_dataloader_fn: a Callable which returns the validation dataloader
        test_dataloader_fn: a Callable which returns the test dataloader
    """
    assert isinstance(model, FlyModel)

    self.config = config
    self.model = model

    # For distributed training
    self.rank = int(os.environ.get("RANK", 0))
    self.local_rank = int(os.environ.get("LOCAL_RANK", 0))
    self.world_size = int(os.environ.get("WORLD_SIZE", 1))
    self.distributed_training = (self.world_size > 1)

    if self.distributed_training and not torch.distributed.is_initialized():
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
        assert torch.distributed.is_initialized()

    if self.distributed_training:
        self.node_rank = os.environ.get("NODE_RANK", "N/A")
        logger.info(
            f"Initialized Rank:{torch.distributed.get_rank()} Local-rank:{self.local_rank} "
            f"on Node:{self.node_rank} Node-name:{socket.gethostname()}"
        )

    logger.info("TrainerLoop is initializing!")

    # set cuda device
    if config.training.num_gpus_per_node > 0:
        torch.cuda.set_device(self.local_rank)
        self.device = torch.device("cuda", self.local_rank)
    else:
        self.device = torch.device("cpu")

    # Setup the dataloaders
    self.train_dataloader = train_dataloader_fn() if train_dataloader_fn else None
    # only rank 0 sets up the validation and test dataloaders
    if self.rank == 0:
        self.validation_dataloader: Iterable = valid_dataloader_fn() if valid_dataloader_fn else None
        self.test_dataloader = test_dataloader_fn() if test_dataloader_fn else None
    else:
        self.validation_dataloader = None
        self.test_dataloader = None

    # Setup callback handler
    self.callback_handler = CallbackHandler(
        config, trainer=self, callbacks=[], verbose=config.training.logging.level == "DEBUG"
    )

    # constants
    self.fp16 = config.training.fp16
    self.gradient_accumulation_batches = config.training.gradient_accumulation_batches
    self.setup_training_constants()

    # local variables
    self.global_batch_count = 0
    self.global_step_count = 0
    self.epochs_trained = 0
    self.local_step_count = 0

    # Configure optimizers
    self.optimizers, self.schedulers = self.model.configure_optimizers(self.total_num_update_steps)

    # Model is sent to GPU or CPU
    self.model = move_to_device(self.model, self.device)

    # Mixed-Precision
    if self.fp16:
        if self.config.training.num_gpus_per_node == 0:
            raise NotImplementedError("For mixed precision training, you need to use a GPU!")
        self.configure_fp16()

    # Distributed Training
    if self.world_size > 1:
        self.configure_ddp()

    # Configure all callbacks
    self.configure_callbacks()
    self.callback_handler.fire_event(Events.INITIALIZE)

    # make sure the model has access to trainer info
    self.model.set_trainer(self)
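
# A minimal usage sketch for this trainer variant. Loading the config with
# OmegaConf, the `MyFlyModel` subclass, and the dataloader factory are
# assumptions for illustration; only the constructor signature follows the
# code above, and it is assumed this variant exposes a no-argument `train()`:
if __name__ == "__main__":
    from omegaconf import OmegaConf

    config = OmegaConf.load("config/train.yaml")    # hypothetical config path
    model = MyFlyModel(config)                      # hypothetical FlyModel subclass

    def train_dataloader_fn():
        # build and return the training DataLoader (hypothetical helper)
        return build_train_dataloader(config)

    trainer = TrainerLoop(config, model, train_dataloader_fn=train_dataloader_fn)
    trainer.train()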