Example #1
    def _train(self, rank=0):
        # Optimizer
        self.optimizer = self.configure_optimizer()
        self.rank = rank
        self.master = rank == 0

        # Training Begin
        self.callback_handler.fire_event(Events.TRAIN_BEGIN)
        self.train_loader = cycle_wrapper(self.train_loader)

        # Scheduler
        self.scheduler = self.configure_scheduler()

        for _ in range(self.global_iter_count, self.total_num_iterations):
            if self.local_iter_count == 0:
                self.callback_handler.fire_event(Events.EPOCH_BEGIN)

            # The RNG state could be preserved via torch.get_rng_state.
            # This is not fully deterministic, but it at least preserves
            # the randomness of data loading across resumes.
            self.batch = next(self.train_loader)

            # callback handler has access to trainer.batch
            self.callback_handler.fire_event(Events.BATCH_BEGIN)

            self.batch = move_to_device(self.batch, self.device)
            self.batch_results = self.train_iter(self.batch)

            # Update the model
            if (self.global_iter_count + 1) % self.config.training.gradient_accumulation_steps == 0:
                self.callback_handler.fire_event(Events.STEP_BEGIN)
                self.optimizer.step()
                self.scheduler.step()
                self.optimizer.zero_grad()
                self.callback_handler.fire_event(Events.STEP_END)

            self.callback_handler.fire_event(Events.BATCH_END)

            # Validation
            if self.master:
                if (self.global_iter_count + 1) % self.config.training.validation_iterations_interval == 0:
                    if self.validation_loader:
                        self.callback_handler.fire_event(Events.VALIDATE_BEGIN)
                        self.model.eval()
                        self.validate_metrics = self.validate()
                        self.callback_handler.fire_event(Events.VALIDATE_END)
                        self.model.train()

            if not self.no_epoch_training and (self.local_iter_count + 1) % self.num_training_batches == 0:
                self.callback_handler.fire_event(Events.EPOCH_END)
                self.epochs_trained += 1
                self.local_iter_count = 0
            else:
                self.local_iter_count += 1

            # Post
            self.global_iter_count += 1

        self.callback_handler.fire_event(Events.TRAIN_END)
        return {}
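The loop above relies on a cycle_wrapper helper that is not shown here. A minimal sketch, assuming it simply cycles a torch DataLoader forever and re-creates the iterator at the end of each pass so per-epoch shuffling stays fresh (name and behavior inferred from usage, not confirmed):

def cycle_wrapper(dataloader):
    """Yield batches indefinitely, restarting the underlying iterator at the
    end of each pass so per-epoch shuffling is preserved."""
    while True:
        for batch in dataloader:
            yield batch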
Example #2
    def train_epoch(self):
        self.optimizer = self.optimizers[0]
        self.scheduler = self.schedulers[0]

        self.local_step_count = 0

        iter_train_dataloader = iter(self.train_dataloader)

        while self.global_step_count < self.total_num_update_steps:
            self.callback_handler.fire_event(Events.BATCH_BEGIN)

            # Collect samples
            buffer_count = 0
            while buffer_count < self.ppo_buffer_size:
                try:
                    batch = next(iter_train_dataloader)
                except StopIteration:
                    iter_train_dataloader = iter(self.train_dataloader)
                    batch = next(iter_train_dataloader)

                self.collect_samples(batch)
                buffer_count += len(batch)

            # Train all samples in the buffer
            for mini_batch in self.replay_buffer.iterate_sample(
                    self.ppo_mini_batch_size):

                # (state, action, action_log_prob, reward, normalized_reward)
                states, actions, action_log_probs, rewards, normalized_rewards = zip(
                    *mini_batch)

                for i in range(len(states)):
                    states[i]["target_token_ids"] = actions[i]

                ppo_batch = self.collate_fn(states)
                ppo_batch["normalized_rewards"] = torch.LongTensor(
                    normalized_rewards)
                ppo_batch["old_log_probs"] = torch.FloatTensor(
                    action_log_probs)

                ppo_batch = move_to_device(ppo_batch, self.device)

                self.tmp_vars["log_dict"] = self.train_step(ppo_batch)

                self.callback_handler.fire_event(Events.STEP_BEGIN)
                self.optimizer.step()
                self.scheduler.step()
                self.optimizer.zero_grad()
                self.callback_handler.fire_event(Events.STEP_END)

            self.callback_handler.fire_event(Events.BATCH_END)
            self.replay_buffer.clear()
            self.global_step_count += 1
            self.local_step_count += 1
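The replay-buffer interface used here (update_batch, iterate_sample, clear) is assumed but not shown. A minimal sketch consistent with how it is called above, assuming each entry is a (state, action, action_log_prob, reward, normalized_reward) tuple and that rewards may be z-score normalized per update; the real implementation may differ:

import random
import numpy as np

class ReplayBuffer:
    def __init__(self):
        self.storage = []

    def update_batch(self, states, actions, action_log_probs, rewards,
                     normalize_reward=True):
        rewards = np.asarray(rewards, dtype=np.float32)
        if normalize_reward:
            # per-batch z-score normalization (assumption, not confirmed)
            normalized = (rewards - rewards.mean()) / (rewards.std() + 1e-8)
        else:
            normalized = rewards
        self.storage.extend(
            zip(states, actions, action_log_probs, rewards, normalized))

    def iterate_sample(self, mini_batch_size):
        # yield shuffled mini-batches of the stored tuples
        indices = list(range(len(self.storage)))
        random.shuffle(indices)
        for start in range(0, len(indices), mini_batch_size):
            yield [self.storage[i]
                   for i in indices[start:start + mini_batch_size]]

    def clear(self):
        self.storage = []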
Example #3
    def validate(self):
        self.model.eval()
        for batch in self.validation_loader:
            # send to cuda device
            batch = move_to_device(batch, self.device)
            with torch.no_grad():
                if self.config.training.num_gpus_per_node > 1:
                    self.model.module.predict(batch)
                else:
                    self.model.predict(batch)
        # get metrics
        if self.config.training.num_gpus_per_node > 1:
            metrics = self.model.module.get_metrics(reset=True)
        else:
            metrics = self.model.get_metrics(reset=True)
        return metrics
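move_to_device appears in every example but is not defined in this section. A plausible sketch, assuming batches are tensors or nested dicts/lists/tuples of tensors (it is also applied to the model itself in Example #9, hence the Module branch):

import torch

def move_to_device(obj, device):
    if torch.is_tensor(obj) or isinstance(obj, torch.nn.Module):
        return obj.to(device)
    if isinstance(obj, dict):
        return {k: move_to_device(v, device) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return type(obj)(move_to_device(v, device) for v in obj)
    # non-tensor leaves (ints, strings, ...) pass through unchanged
    return obj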
Example #4
    def collect_samples(self, batch):
        """Generate samples, collect rewards, and update replay buffer"""

        num_human_demos = int(len(batch) * self.mix_human_demo_ratio)
        # num_generations = int(len(batch) * (1 - self.mix_human_demo_ratio))

        self.mix_human_demo_ratio *= self.mix_human_demo_ratio_decay

        for i in range(0, num_human_demos, self.sample_batch_size):
            # Update Buffer for Human Demos
            human_demos_batch = batch[i:i + self.sample_batch_size]

            human_log_probs = [torch.zeros((item["source_token_ids"].shape[0])) for item in human_demos_batch]
            human_tokens = [item["source_token_ids"] for item in human_demos_batch]

            if self.config.text_ppo.constant_human_demo_reward:
                rewards = np.ones((len(human_demos_batch))) * 2.0

                self.replay_buffer.update_batch(
                    states=human_demos_batch,
                    actions=human_tokens,
                    action_log_probs=human_log_probs,
                    rewards=rewards,
                    normalize_reward=False
                )
            else:
                results = {}
                results["tokens"] = [item["target_token_ids"] for item in human_demos_batch]
                rewards = self.reward_func(human_demos_batch, results, is_human_demo=True)
                self.replay_buffer.update_batch(
                    states=human_demos_batch, actions=human_tokens, action_log_probs=human_log_probs, rewards=rewards
                )
        # Update Buffer for Generations
        for i in range(num_human_demos, len(batch), self.sample_batch_size):
            sample_batch = batch[i:i + self.sample_batch_size]
            sample_batch_collated = self.collator.sample_collate(sample_batch)
            sample_batch_collated = move_to_device(sample_batch_collated, self.device)
            results = self.decoder_generate(sample_batch_collated)

            # TODO: Consider num_return_sequences
            results["tokens"] = [item[0] for item in results["tokens"]]
            results["log_probs"] = [item[0] for item in results["log_probs"]]
            rewards = self.reward_func(sample_batch, results, is_human_demo=False)

            self.replay_buffer.update_batch(
                states=sample_batch, actions=results["tokens"], action_log_probs=results["log_probs"], rewards=rewards
            )
Example #5
    def train_epoch(self):
        self.optimizer = self.optimizers[0]
        self.scheduler = self.schedulers[0]

        self.local_step_count = 0

        for batch in self.train_dataloader:
            self.callback_handler.fire_event(Events.BATCH_BEGIN)

            batch = move_to_device(batch, self.device)
            self.tmp_vars["log_dict"] = self.train_step(batch)

            # Update the model
            if (self.global_step_count +
                    1) % self.gradient_accumulation_steps == 0:
                self.step_update()

            self.callback_handler.fire_event(Events.BATCH_END)

            # Only rank 0 runs validation
            if self.rank == 0:
                if self.global_step_count > self.validation_after_num_steps and \
                    ((self.global_step_count + 1) % self.validation_steps_interval == 0):

                    if self.validation_dataloader is not None:
                        self.model.eval()
                        self.model.is_training = False
                        # BEGIN
                        self.callback_handler.fire_event(Events.VALIDATE_BEGIN)

                        self.tmp_vars["validate_metrics"] = self.validate()

                        self.callback_handler.fire_event(Events.VALIDATE_END)
                        self.model.train()
                        self.model.is_training = True

            if self.config.training.num_gpus_per_node > 1:
                torch.distributed.barrier()
            if self.global_step_count >= self.total_num_steps:
                break

            self.global_step_count += 1
            self.local_step_count += 1
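All of these trainers dispatch lifecycle hooks through Events and CallbackHandler, neither of which is shown here. A minimal sketch of that event system, assuming fire_event simply forwards to per-event callback methods; the actual member names and dispatch logic are assumptions:

from enum import Enum

class Events(Enum):
    INITIALIZE = "initialize"
    TRAIN_BEGIN = "train_begin"
    TRAIN_END = "train_end"
    EPOCH_BEGIN = "epoch_begin"
    EPOCH_END = "epoch_end"
    BATCH_BEGIN = "batch_begin"
    BATCH_END = "batch_end"
    STEP_BEGIN = "step_begin"
    STEP_END = "step_end"
    VALIDATE_BEGIN = "validate_begin"
    VALIDATE_END = "validate_end"

class CallbackHandler:
    def __init__(self, config, trainer, callbacks=None, verbose=False):
        self.config = config
        self.trainer = trainer
        self.callbacks = callbacks or []
        self.verbose = verbose

    def fire_event(self, event):
        # each callback may implement a handler named after the event,
        # e.g. on_batch_begin(trainer); callbacks without one are skipped
        for callback in self.callbacks:
            handler = getattr(callback, f"on_{event.value}", None)
            if handler is not None:
                handler(self.trainer)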
Example #6
    def validate(self):
        # Validation
        self.model.eval()
        # No gradient is needed for validation
        with torch.no_grad():
            for batch in iter(self.validation_dataloader):
                # send to cuda device
                batch = move_to_device(batch, self.device)

                if self.distributed_training:
                    self.model.module.predict(batch)
                else:
                    self.model.predict(batch)
        # END
        # get metrics
        if self.distributed_training:
            metrics = self.model.module.get_metrics(reset=True)
        else:
            metrics = self.model.get_metrics(reset=True)
        return metrics
    def train_epoch(self):
        self.D_optimizer = self.optimizers[0]
        self.G_optimizer = self.optimizers[1]
        self.D_scheduler = self.schedulers[0]
        self.G_scheduler = self.schedulers[1]

        prop_decay = self.mix_human_demo_ratio / self.total_num_update_steps

        # total counts
        self.local_step_count = 0
        self.generator_step_count = 0
        self.discriminator_step_count = 0

        iter_train_dataloader = iter(self.train_dataloader)

        while self.global_step_count < self.total_num_update_steps:
            self.callback_handler.fire_event(Events.BATCH_BEGIN)

            # Collect samples
            buffer_count = 0
            while buffer_count < self.ppo_buffer_size:
                try:
                    batch = next(iter_train_dataloader)
                except StopIteration:
                    iter_train_dataloader = iter(self.train_dataloader)
                    self.epochs_trained += 1
                    logger.info(f"{self.epochs_trained} has finished!")
                    batch = next(iter_train_dataloader)

                self.collect_samples(batch)
                buffer_count += len(batch)

            # Discriminator Warmup
            if not self.done_discriminator_pretrain:
                for mini_batch in self.replay_buffer.iterate_sample(
                        self.ppo_mini_batch_size):
                    # (state, action, action_log_prob, reward, normalized_reward)
                    states, actions, action_log_probs, rewards, normalized_rewards = zip(
                        *mini_batch)
                    self.tmp_vars["log_dict"] = self.train_discriminator_step(
                        states, actions)

                    if (self.discriminator_step_count +
                            1) % self.gradient_accumulation_steps == 0:
                        # self.callback_handler.fire_event(Events.STEP_BEGIN)
                        self.D_optimizer.step()
                        self.D_scheduler.step()
                        self.D_optimizer.zero_grad()
                        # self.callback_handler.fire_event(Events.STEP_END)

                        self.discriminator_step_count += 1

                    if self.discriminator_step_count >= self.discriminator_pretrain_steps:
                        self.done_discriminator_pretrain = True
                        break
            else:
                "Generator Training"
                for _ in range(self.ppo_epoch):
                    # Train the Generator
                    for mini_batch in self.replay_buffer.iterate_sample(
                            self.ppo_mini_batch_size):

                        # (state, action, action_log_prob, reward, normalized_reward)
                        states, actions, action_log_probs, rewards, normalized_rewards = zip(
                            *mini_batch)

                        ppo_batch = self.collate_fn(states)

                        ppo_batch["target_token_ids"] = pad_sequence(
                            actions,
                            batch_first=True,
                            padding_value=self.pad_token_id)
                        ppo_batch["normalized_rewards"] = torch.LongTensor(
                            normalized_rewards)
                        ppo_batch["old_log_probs"] = torch.FloatTensor(
                            action_log_probs)

                        ppo_batch = move_to_device(ppo_batch, self.device)

                        self.tmp_vars["log_dict"] = self.train_generator_step(
                            ppo_batch)

                        if (self.generator_step_count +
                                1) % self.gradient_accumulation_steps == 0:
                            # self.callback_handler.fire_event(Events.STEP_BEGIN)
                            self.G_optimizer.step()
                            self.G_scheduler.step()
                            self.G_optimizer.zero_grad()
                            # self.callback_handler.fire_event(Events.STEP_END)

                            self.generator_step_count += 1

                "Discriminator Training"
                for mini_batch in self.replay_buffer.iterate_sample(
                        self.ppo_mini_batch_size):
                    states, actions, action_log_probs, rewards, normalized_rewards = zip(
                        *mini_batch)
                    log_dict = self.train_discriminator_step(states, actions)
                    self.tmp_vars["log_dict"].update(log_dict)

                    if (self.discriminator_step_count +
                            1) % self.gradient_accumulation_steps == 0:
                        # self.callback_handler.fire_event(Events.STEP_BEGIN)
                        self.D_optimizer.step()
                        self.D_scheduler.step()
                        self.D_optimizer.zero_grad()
                        # self.callback_handler.fire_event(Events.STEP_END)

                        self.discriminator_step_count += 1

            # decay mix_human_demo_ratio
            self.mix_human_demo_ratio -= prop_decay
            self.tmp_vars["log_dict"][
                "mix_human_demo_ratio"] = self.mix_human_demo_ratio

            self.callback_handler.fire_event(Events.BATCH_END)
            self.replay_buffer.clear()

            # Only rank 0 runs validation
            if self.rank == 0:
                if (self.global_step_count +
                        1) % self.validation_steps_interval == 0:
                    if self.validation_dataloader is not None:
                        self.model.eval()
                        # BEGIN
                        self.callback_handler.fire_event(Events.VALIDATE_BEGIN)

                        self.tmp_vars["validate_metrics"] = self.validate()

                        self.callback_handler.fire_event(Events.VALIDATE_END)
                        self.model.train()

            self.global_step_count += 1
            self.local_step_count += 1
    def collect_samples(self, batch):
        """Generate samples, collect rewards, and update replay buffer"""

        num_human_demos = int(len(batch) * self.mix_human_demo_ratio)
        # num_generations = int(len(batch) * (1 - self.mix_human_demo_ratio))

        actual_sample_size = min(num_human_demos, self.sample_batch_size)
        if actual_sample_size > 0:
            for i in range(0, num_human_demos, actual_sample_size):
                # Update Buffer for Human Demos
                human_demos_batch = batch[i:i + actual_sample_size]

                # collect human demos log probs
                human_demos_batch_collated = self.collate_fn(human_demos_batch)
                human_demos_batch_collated = move_to_device(
                    human_demos_batch_collated, self.device)
                log_probs = self.model.generator.compute_log_probs(
                    human_demos_batch_collated)["log_probs"]
                human_log_probs = log_probs.tolist()

                human_tokens = [
                    item["target_token_ids"] for item in human_demos_batch
                ]

                if self.constant_human_demo_reward:
                    rewards = np.ones((len(human_demos_batch))) * 2.0
                    self.replay_buffer.update_batch(
                        states=human_demos_batch,
                        actions=human_tokens,
                        action_log_probs=human_log_probs,
                        rewards=rewards,
                        normalize_reward=False)
                else:
                    results = {}
                    results["tokens"] = human_tokens
                    rewards = self.reward_func.get_reward(
                        human_demos_batch, results)
                    self.replay_buffer.update_batch(
                        states=human_demos_batch,
                        actions=human_tokens,
                        action_log_probs=human_log_probs,
                        rewards=rewards)

        # Update Buffer for Generations
        # guard against a zero step when every sample in the batch is a human demo
        actual_sample_size = max(
            1, min(len(batch) - num_human_demos, self.sample_batch_size))
        for i in range(num_human_demos, len(batch), actual_sample_size):
            sample_batch = batch[i:i + actual_sample_size]
            sample_batch_collated = self.collate_fn(sample_batch)
            sample_batch_collated = move_to_device(sample_batch_collated,
                                                   self.device)
            results = self.decoder_generate(sample_batch_collated)

            # TODO: Consider num_return_sequences
            results["tokens"] = [item[0] for item in results["tokens"]]

            # recompute the log probs for better precision
            if self.recompute_log_probs:
                temp_target_token_ids = sample_batch_collated[
                    "target_token_ids"]
                sample_batch_collated["target_token_ids"] = pad_sequence(
                    results["tokens"],
                    batch_first=True,
                    padding_value=self.pad_token_id).to(self.device)

                log_probs = self.model.generator.compute_log_probs(
                    sample_batch_collated)["log_probs"]
                results["log_probs"] = log_probs.tolist()
                # we switch back the original target_token_ids
                sample_batch_collated[
                    "target_token_ids"] = temp_target_token_ids
            else:
                results["log_probs"] = [
                    item[0] for item in results["log_probs"]
                ]

            rewards = self.reward_func.get_reward(sample_batch, results)

            self.replay_buffer.update_batch(
                states=sample_batch,
                actions=results["tokens"],
                action_log_probs=results["log_probs"],
                rewards=rewards)
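collect_samples assumes generator.compute_log_probs returns one summed log-probability per sequence for a collated batch. A sketch of the core computation only (not the exact method signature), assuming access to the model's logits over target_token_ids and the pad id:

import torch
import torch.nn.functional as F

def sequence_log_probs(logits, target_token_ids, pad_token_id):
    """logits: (batch, seq_len, vocab); target_token_ids: (batch, seq_len)."""
    log_probs = F.log_softmax(logits, dim=-1)
    # log-probability of each realized target token
    token_log_probs = log_probs.gather(
        -1, target_token_ids.unsqueeze(-1)).squeeze(-1)
    # mask out padding positions before summing over the sequence
    mask = (target_token_ids != pad_token_id).float()
    return (token_log_probs * mask).sum(dim=-1)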
Example #9
    def __init__(self,
                 config: DictConfig,
                 model: FlyModel,
                 train_dataloader_fn: Callable,
                 valid_dataloader_fn: Callable = None,
                 test_dataloader_fn: Callable = None):
        """
        Args:
            config: FlyConfig dictionary
            model: must be FlyModel
            train_dataloader_fn/valid_dataloader_fn/test_dataloader_fn:
                callables which return the corresponding dataloaders
        """
        assert isinstance(model, FlyModel)

        self.config = config
        self.rank, self.local_rank = get_rank()

        # Distributed
        if self.config.training.num_gpus_per_node > 1:
            # Init distributed
            # TODO: multi-node multi-gpu training
            torch.distributed.init_process_group(
                backend="nccl",
                rank=self.rank,
                world_size=self.config.training.num_gpus_per_node * 1)

        # configure distributed training
        self.model = model

        self.train_dataloader = train_dataloader_fn(config)
        self.validation_dataloader: Iterable = valid_dataloader_fn(
            config) if valid_dataloader_fn else None
        self.test_dataloader = test_dataloader_fn(
            config) if test_dataloader_fn else None

        self.callback_handler = CallbackHandler(
            config,
            trainer=self,
            callbacks=[],
            verbose=config.training.logging.level == "DEBUG")

        # constants
        self.gradient_accumulation_steps = config.training.optimization.gradient_accumulation_steps
        self.validation_steps_interval = config.training.validation.steps_interval
        self.fp16 = config.training.optimization.fp16
        self.fp16_opt_level = config.training.optimization.fp16_opt_level
        self.distributed_training = False

        self.total_num_update_steps = int(
            config.training.total_num.update_steps)
        self.total_num_steps = self.total_num_update_steps * int(
            self.gradient_accumulation_steps)
        self.total_num_epochs = int(self.config.training.total_num.epochs)

        # Train in epochs or steps
        if self.total_num_epochs > 0:
            self.training_in_epoch = True
        else:
            if self.total_num_update_steps < 0:
                raise NotImplementedError(
                    "config.training.total_num.updated_steps must be larger than 0"
                )
            self.training_in_epoch = False
            self.total_num_epochs = 1

        # Number of training batches
        if self.training_in_epoch:
            try:
                self.epoch_num_training_steps = len(self.train_dataloader)
                self.total_num_training_steps = self.epoch_num_training_steps * self.total_num_epochs
                self.total_num_update_steps = self.total_num_training_steps // self.gradient_accumulation_steps
            except TypeError:
                # cannot derive total_num_epochs because the length of the
                # train dataloader is unknown (e.g. an iterable dataset)
                logger.error("Cannot get the length of the train dataloader")
                raise NotImplementedError(
                    "Please specify the `total_num_epochs` or `total_num_update_steps`!"
                )
        else:
            self.epoch_num_training_steps = self.total_num_update_steps

        # Validation steps interval
        self.validation_after_num_steps = config.training.validation.after_num_steps
        if self.validation_steps_interval < 0:
            self.validation_steps_interval = self.epoch_num_training_steps - 1

        # local variables
        self.global_step_count = 0
        self.epochs_trained = 0
        self.local_step_count = 0

        # set cuda device
        if config.training.num_gpus_per_node > 1:
            torch.cuda.set_device(self.rank)
            self.device = torch.device("cuda", self.local_rank)
        elif config.training.num_gpus_per_node == 1:
            self.device = torch.device("cuda")
        else:
            self.device = torch.device("cpu")

        # Configure optimizers
        self.optimizers, self.schedulers = self.model.configure_optimizers(
            self.total_num_update_steps)

        # Model is sent to GPU or CPU
        self.model = move_to_device(self.model, self.device)

        # Mixed-Precision
        if self.fp16 and self.config.training.num_gpus_per_node > 0:
            self.configure_fp16()

        # Distributed Training
        if self.config.training.num_gpus_per_node > 1:
            self.configure_ddp()

        self.configure_callbacks()

        self.log_keys = set()
        self.tmp_vars = {}
        self.callback_handler.fire_event(Events.INITIALIZE)

        # make sure the model has access to trainer info
        self.model.set_trainer(self)
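get_rank, configure_ddp, and configure_fp16 are referenced above but not defined in this section. Sketches of the first two under the single-node NCCL setup shown (environment-variable ranks, DistributedDataParallel on the local device); the real helpers may differ:

import os
import torch

def get_rank():
    """Read global/local rank from the launcher's environment variables
    (torchrun / torch.distributed.launch); defaults to a single process."""
    return int(os.environ.get("RANK", 0)), int(os.environ.get("LOCAL_RANK", 0))

# Trainer method sketch: wrap the already GPU-resident model for
# single-node distributed data-parallel training.
def configure_ddp(self):
    self.distributed_training = True
    self.model = torch.nn.parallel.DistributedDataParallel(
        self.model,
        device_ids=[self.local_rank],
        output_device=self.local_rank,
        find_unused_parameters=True)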