Example #1
    def train_epoch(self):
        self.optimizer = self.optimizers[0]
        self.scheduler = self.schedulers[0]

        self.local_step_count = 0

        if self.train_dataloader is None:
            return

        for batch in self.train_dataloader:
            self.callback_handler.fire_event(Events.BATCH_BEGIN)

            batch = move_to_device(batch, self.device)
            output = self.backward_batch(batch)

            # Step the optimizer once every `gradient_accumulation_batches` batches
            if (self.global_batch_count + 1) % self.gradient_accumulation_batches == 0:
                self.step_update(self.model, self.optimizer, self.scheduler)
                self.global_step_count += 1
                self.local_step_count += 1

            self.callback_handler.fire_event(Events.BATCH_END)

            if self.global_step_count >= self.total_num_update_steps:
                break

            self.global_batch_count += 1
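
The loop above delegates the actual parameter update to `step_update`, which isn't shown in this excerpt. A minimal sketch of what such a helper typically does, assuming a `max_gradient_norm` attribute like the one configured in Example #7:

    import torch

    def step_update(self, model, optimizer, scheduler):
        # Clip gradients, apply the optimizer, advance the LR schedule,
        # and clear the accumulated gradients for the next window.
        torch.nn.utils.clip_grad_norm_(model.parameters(), self.max_gradient_norm)
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
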
Example #2
 def test_loop(self, dataloader):
     self.eval()
     self.reset_evaluation_metrics()
     # No gradient is needed for validation
     with torch.no_grad():
         pbar = tqdm.tqdm(dataloader)
         pbar.mininterval = 2.0
         for batch_idx, batch in enumerate(pbar):
             # send batch to the target device
             batch = move_to_device(batch, self.device)
             self.test_step(batch, batch_idx)
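
`test_loop` (and the other loops below) rely on a `move_to_device` helper that isn't part of this excerpt. A minimal sketch, assuming batches are nested dicts, lists, or tuples of tensors:

    import torch

    def move_to_device(obj, device):
        # Recursively move every tensor inside a nested container to `device`.
        if torch.is_tensor(obj):
            return obj.to(device)
        if isinstance(obj, dict):
            return {k: move_to_device(v, device) for k, v in obj.items()}
        if isinstance(obj, (list, tuple)):
            return type(obj)(move_to_device(v, device) for v in obj)
        return obj  # leave non-tensor leaves unchanged
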
Example #3
    def train_epoch(self):
        self.optimizer = self.optimizers[0]
        self.scheduler = self.schedulers[0]

        self.local_step_count = 0

        iter_train_dataloader = iter(self.train_dataloader)

        while self.global_step_count < self.total_num_update_steps:
            self.callback_handler.fire_event(Events.BATCH_BEGIN)

            # Collect samples
            buffer_count = 0
            while buffer_count < self.ppo_buffer_size:
                try:
                    batch = next(iter_train_dataloader)
                except StopIteration:
                    iter_train_dataloader = iter(self.train_dataloader)
                    batch = next(iter_train_dataloader)

                self.collect_samples(batch)
                buffer_count += len(batch)

            # Train all samples in the buffer
            for mini_batch in self.replay_buffer.iterate_sample(
                    self.ppo_mini_batch_size):

                # (state, action, action_log_prob, reward, normalized_reward)
                states, actions, action_log_probs, rewards, normalized_rewards = zip(
                    *mini_batch)

                for i in range(len(states)):
                    states[i]["target_token_ids"] = actions[i]

                ppo_batch = self.collate_fn(states)
                ppo_batch["normalized_rewards"] = torch.LongTensor(
                    normalized_rewards)
                ppo_batch["old_log_probs"] = torch.FloatTensor(
                    action_log_probs)

                ppo_batch = move_to_device(ppo_batch, self.device)

                self.tmp_vars["log_dict"] = self.train_step(ppo_batch)

                self.callback_handler.fire_event(Events.STEP_BEGIN)
                self.optimizer.step()
                self.scheduler.step()
                self.optimizer.zero_grad()
                self.callback_handler.fire_event(Events.STEP_END)

            self.callback_handler.fire_event(Events.BATCH_END)
            self.replay_buffer.clear()
            self.global_step_count += 1
            self.local_step_count += 1
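
The PPO loop assumes a replay buffer whose `iterate_sample` yields mini-batches of `(state, action, action_log_prob, reward, normalized_reward)` tuples. A hypothetical minimal buffer matching the calls above:

    import random

    class ReplayBuffer:
        def __init__(self):
            # each entry: (state, action, action_log_prob, reward, normalized_reward)
            self.samples = []

        def iterate_sample(self, mini_batch_size):
            # Shuffle once, then yield fixed-size chunks of the buffer.
            random.shuffle(self.samples)
            for i in range(0, len(self.samples), mini_batch_size):
                yield self.samples[i:i + mini_batch_size]

        def clear(self):
            self.samples.clear()
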
Example #4
    def __init__(self, config: DictConfig, model: FlyModel, name: str = "task1", *args, **kwargs):
        """
        Args:
            config: FlyConfig dictionary
            model: must be FlyModel
            dataloader_fn: a Callable function which returns dataloaders
        """
        logger.info("TrainerLoop is initializing!")
        if not isinstance(model, FlyModel):
            logger.warning("model is not an instance of FlyModel")
        self.config = config
        self.model = model
        self.name = name

        # class properties
        self.rank = None
        self.local_rank = None
        self.node_rank = None
        self.world_size = None
        self.distributed_training = None
        self.device = None
        self.fp16 = config.fp16
        self.gradient_accumulation_batches = config.gradient_accumulation_batches
        self.callback_handler = None
        self.optimizers = []
        self.schedulers = []

        self.init_distributed_environment()

        # Model is sent to GPU or CPU
        self.init_device()
        # self.optimizers, self.schedulers = self.configure_optimizers()

        self.model = move_to_device(self.model, self.device)
        self.model.device = self.device
        self.init_fp16()

        if self.distributed_training:
            self.init_distributed_model(self.model)

        # make sure the model has access to trainer info
        self.model.set_trainer(self)

        self.callback_handler = CallbackHandler(config,
                                                trainer=self,
                                                callbacks=[],
                                                verbose=config.logging.level == "DEBUG")

        # Configure all callbacks
        self.configure_callbacks()
        self.callback_handler.fire_event(Events.INITIALIZE)
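
The trainer talks to its callbacks only through `fire_event`. Neither `Events` nor `CallbackHandler` appears in this excerpt; a minimal sketch of the assumed contract (the event names are taken from these examples, the dispatch logic is hypothetical):

    from enum import Enum, auto

    class Events(Enum):
        INITIALIZE = auto()
        TRAIN_BEGIN = auto()
        EPOCH_BEGIN = auto()
        BATCH_BEGIN = auto()
        STEP_BEGIN = auto()
        STEP_END = auto()
        BATCH_END = auto()
        EPOCH_END = auto()
        TRAIN_END = auto()
        VALIDATE_BEGIN = auto()
        VALIDATE_END = auto()
        TEST_BEGIN = auto()
        TEST_END = auto()

    class CallbackHandler:
        def __init__(self, config, trainer, callbacks, verbose=False):
            self.trainer = trainer
            self.callbacks = list(callbacks)

        def fire_event(self, event):
            # Dispatch the event to every registered callback in order.
            for callback in self.callbacks:
                callback(event, self.trainer)
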
Example #5
    def test(self):
        # Start Testing
        self.model.eval()
        self.model.reset_evaluation_metrics()
        self.callback_handler.fire_event(Events.TEST_BEGIN)
        # No gradient is needed for test
        with torch.no_grad():
            pbar = tqdm.tqdm(self.test_dataloader)
            pbar.mininterval = 2.0
            for batch in pbar:
                # send batch to the target device
                batch = move_to_device(batch, self.device)
                self.model.predict(batch)

        self.callback_handler.fire_event(Events.TEST_END)
Example #6
    def validate(self):
        # Start Validation
        self.model.eval()
        self.model.reset_evaluation_metrics()
        self.callback_handler.fire_event(Events.VALIDATE_BEGIN)
        # No gradient is needed for validation
        with torch.no_grad():
            pbar = tqdm.tqdm(self.validation_dataloader)
            pbar.mininterval = 2.0
            for batch in pbar:
                # send batch to the target device
                batch = move_to_device(batch, self.device)
                self.model.predict(batch)

        self.callback_handler.fire_event(Events.VALIDATE_END)
Example #7
    def train(
        self,
        config,
        train_dataloader,
        validation_dataloader=None,
        test_dataloader=None,
        configure_optimizers=True,
        stage_name: str = "Stage1",
        *args,
        **kwargs,
    ):
        self.config = config
        self.stage_name = stage_name

        # Model is sent to GPU or CPU
        self.init_device(config)
        # self.optimizers, self.schedulers = self.configure_optimizers()

        self.gradient_accumulation_batches = config.gradient_accumulation_batches
        self.max_gradient_norm = config.optimization.max_gradient_norm
        self.fp16 = config.fp16
        self.model = move_to_device(self.model, self.device)
        self.model.device = self.device
        self.init_fp16(config)

        if self.distributed_training:
            self.init_distributed_model(self.model)

        self.total_num_update_steps = 0
        self.total_num_batches = 0
        self.total_num_epochs = 0
        self.epoch_num_batches = 0
        self.global_batch_count = 0
        self.global_step_count = 0
        self.epochs_trained = 0
        self.local_step_count = 0

        self.train_dataloader = train_dataloader
        self.validation_dataloader = validation_dataloader
        self.test_dataloader = test_dataloader

        self.init_training_constants(config)

        if configure_optimizers or len(self.optimizers) == 0:
            self.configure_optimizers(config, self.total_num_update_steps)

        # Training begins
        self.callback_handler.fire_event(Events.TRAIN_BEGIN)

        while True:
            self.callback_handler.fire_event(Events.EPOCH_BEGIN)
            self.train_epoch()
            self.callback_handler.fire_event(Events.EPOCH_END)
            self.epochs_trained += 1

            # Epoch-budgeted runs stop after `total_num_epochs`; step-budgeted
            # runs stop once the update-step target is reached.
            if self.training_in_epoch:
                if self.epochs_trained >= self.total_num_epochs:
                    break
            elif self.global_step_count >= self.total_num_update_steps:
                break

        # Training ends
        self.callback_handler.fire_event(Events.TRAIN_END)
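
`train` relies on `init_training_constants(config)` to set `total_num_update_steps`, `total_num_epochs`, and the `training_in_epoch` flag that drives the loop exit above. A plausible sketch, assuming the config carries either a `num_epochs` or a `num_update_steps` field (both names hypothetical):

    def init_training_constants(self, config):
        self.epoch_num_batches = len(self.train_dataloader)
        if getattr(config, "num_epochs", None):
            # Epoch budget: convert total batches into optimizer update steps.
            self.training_in_epoch = True
            self.total_num_epochs = config.num_epochs
            self.total_num_batches = self.epoch_num_batches * self.total_num_epochs
            self.total_num_update_steps = (
                self.total_num_batches // self.gradient_accumulation_batches)
        else:
            # Step budget: train until `global_step_count` reaches the target.
            self.training_in_epoch = False
            self.total_num_update_steps = config.num_update_steps
            self.total_num_batches = (
                self.total_num_update_steps * self.gradient_accumulation_batches)
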
Example #8
    def collect_samples(self, batch):
        """Generate samples, collect rewards, and update replay buffer"""

        num_human_demos = int(len(batch) * self.mix_human_demo_ratio)
        # num_generations = int(len(batch) * (1 - self.mix_human_demo_ratio))

        self.mix_human_demo_ratio *= self.mix_human_demo_ratio_decay

        # guard against a zero range step when there are no human demos
        actual_sample_size = max(1, min(num_human_demos, self.sample_batch_size))
        for i in range(0, num_human_demos, actual_sample_size):
            # Update Buffer for Human Demos
            human_demos_batch = batch[i:i + actual_sample_size]

            # collect human demos log probs
            human_demos_batch_collated = self.collate_fn(human_demos_batch)
            human_demos_batch_collated = move_to_device(
                human_demos_batch_collated, self.device)
            log_probs = self.model.compute_log_probs(
                human_demos_batch_collated)["log_probs"]
            human_log_probs = log_probs.tolist()

            human_tokens = [
                item["target_token_ids"] for item in human_demos_batch
            ]

            if self.config.text_ppo.constant_human_demo_reward:
                rewards = np.ones((len(human_demos_batch))) * 2.0

                self.replay_buffer.update_batch(
                    states=human_demos_batch,
                    actions=human_tokens,
                    action_log_probs=human_log_probs,
                    rewards=rewards,
                    normalize_reward=False)
            else:
                results = {}
                results["tokens"] = human_tokens
                rewards = self.reward_func(human_demos_batch_collated, results)
                self.replay_buffer.update_batch(
                    states=human_demos_batch,
                    actions=human_tokens,
                    action_log_probs=human_log_probs,
                    rewards=rewards)

        # Update Buffer for Generations
        # likewise guard against a zero step when every sample is a human demo
        actual_sample_size = max(
            1, min(len(batch) - num_human_demos, self.sample_batch_size))
        for i in range(num_human_demos, len(batch), actual_sample_size):
            sample_batch = batch[i:i + actual_sample_size]
            sample_batch_collated = self.collate_fn(sample_batch)
            sample_batch_collated = move_to_device(sample_batch_collated,
                                                   self.device)
            results = self.decoder_generate(sample_batch_collated)

            # TODO: Consider num_return_sequences
            results["tokens"] = [item[0] for item in results["tokens"]]

            if self.recompute_log_probs:
                # Pad the generated tokens and recompute their log probs
                # under the current policy
                sample_batch_collated["target_token_ids"] = pad_sequence(
                    results["tokens"],
                    batch_first=True,
                    padding_value=self.pad_token_id).to(self.device)

                log_probs = self.model.compute_log_probs(
                    sample_batch_collated)["log_probs"]
                results["log_probs"] = log_probs.tolist()
            else:
                results["log_probs"] = [
                    item[0] for item in results["log_probs"]
                ]

            rewards = self.reward_func(sample_batch_collated, results)

            self.replay_buffer.update_batch(
                states=sample_batch,
                actions=results["tokens"],
                action_log_probs=results["log_probs"],
                rewards=rewards)
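
`update_batch` is assumed to standardize rewards (unless `normalize_reward=False`, as in the constant human-demo branch) and to store one tuple per sample. A sketch under that assumption, continuing the hypothetical `ReplayBuffer` from Example #3:

    import numpy as np

    def update_batch(self, states, actions, action_log_probs, rewards,
                     normalize_reward=True):
        rewards = np.asarray(rewards, dtype=np.float32)
        if normalize_reward:
            # standardize across the batch; epsilon avoids division by zero
            normalized = (rewards - rewards.mean()) / (rewards.std() + 1e-8)
        else:
            normalized = rewards
        for sample in zip(states, actions, action_log_probs,
                          rewards.tolist(), normalized.tolist()):
            self.samples.append(sample)
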
Example #9
    def __init__(
        self,
        config: DictConfig,
        model: FlyModel,
        train_dataloader_fn: Callable,
        valid_dataloader_fn: Callable = None,
        test_dataloader_fn: Callable = None
    ):
        """
        Args:
            config: FlyConfig dictionary
            model: must be FlyModel
            dataloader_fn: a Callable function which returns dataloaders
        """
        assert isinstance(model, FlyModel)
        self.config = config
        self.model = model

        # For distributed
        self.rank = int(os.environ.get("RANK", 0))
        self.local_rank = int(os.environ.get("LOCAL_RANK", 0))
        self.world_size = int(os.environ.get("WORLD_SIZE", 1))
        self.distributed_training = (self.world_size > 1)

        if self.distributed_training and not torch.distributed.is_initialized():
            torch.distributed.init_process_group(backend='nccl', init_method='env://')
            assert torch.distributed.is_initialized()

        if self.distributed_training and torch.distributed.is_initialized():
            self.node_rank = os.environ.get("NODE_RANK", "N/A")
            logger.info(
                f"Initialized Rank:{torch.distributed.get_rank()} Local-rank:{self.local_rank} on Node:{self.node_rank} Node-name:{socket.gethostname()}"
            )

        logger.info("TrainerLoop is initializing!")

        # set cuda device
        if config.training.num_gpus_per_node > 0:
            torch.cuda.set_device(self.local_rank)
            self.device = torch.device("cuda", self.local_rank)
        else:
            self.device = torch.device("cpu")

        # Set up the dataloaders
        self.train_dataloader = train_dataloader_fn() if train_dataloader_fn else None
        # only rank 0 sets up the validation and test dataloaders
        if self.rank == 0:
            self.validation_dataloader: Iterable = valid_dataloader_fn() if valid_dataloader_fn else None
            self.test_dataloader = test_dataloader_fn() if test_dataloader_fn else None

        # Setup callback handler
        self.callback_handler = CallbackHandler(
            config, trainer=self, callbacks=[], verbose=config.training.logging.level == "DEBUG"
        )

        # constants
        self.fp16 = config.training.fp16
        self.gradient_accumulation_batches = config.training.gradient_accumulation_batches

        self.setup_training_constants()

        # local variables
        self.global_batch_count = 0
        self.global_step_count = 0
        self.epochs_trained = 0
        self.local_step_count = 0

        # Configure optimizers against the total update-step budget
        self.optimizers, self.schedulers = self.model.configure_optimizers(self.total_num_update_steps)

        # Model is sent to GPU or CPU
        self.model = move_to_device(self.model, self.device)

        # Mixed-Precision
        if self.fp16:
            if self.config.training.num_gpus_per_node == 0:
                raise NotImplementedError("For mixed precision training, you need to use GPU!")
            self.configure_fp16()

        # Distributed Training
        if self.world_size > 1:
            self.configure_ddp()

        # Configure all callbacks
        self.configure_callbacks()
        self.callback_handler.fire_event(Events.INITIALIZE)

        # make sure the model has access to trainer info
        self.model.set_trainer(self)
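
`configure_ddp` isn't shown; a minimal sketch using `torch.nn.parallel.DistributedDataParallel`, assuming the process group has already been initialized and the model already sits on `self.device`:

    import torch

    def configure_ddp(self):
        # Wrap the model for synchronized gradient all-reduce across ranks.
        use_cuda = self.config.training.num_gpus_per_node > 0
        self.model = torch.nn.parallel.DistributedDataParallel(
            self.model,
            device_ids=[self.local_rank] if use_cuda else None,
            output_device=self.local_rank if use_cuda else None,
        )
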