Example #1
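
# NOTE: reconstructed imports, inferred from usage in this example. The
# project-local names (FlyModel, Callback, CallbackHandler, Checkpoint,
# Events, LogHandler, get_rank, get_lr, get_log_variable, move_to_device)
# are assumed to come from the surrounding codebase.
import logging
from typing import Callable, Iterable, List, Optional

import torch
from apex import amp
from apex.parallel import DistributedDataParallel
from omegaconf import DictConfig

logger = logging.getLogger(__name__)
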
class TrainerLoop:
    def __init__(self,
                 config: DictConfig,
                 model: FlyModel,
                 train_dataloader_fn: Callable,
                 valid_dataloader_fn: Optional[Callable] = None,
                 test_dataloader_fn: Optional[Callable] = None):
        """
        Args:
            config: FlyConfig dictionary
            model: must be FlyModel
            dataloader_fn: a Callable function which returns dataloaders
        """
        assert isinstance(model, FlyModel), "model must be a FlyModel"

        self.config = config
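        # `get_rank` is assumed to return (global_rank, local_rank)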
        self.rank, self.local_rank = get_rank()

        # Distributed
        if self.config.training.num_gpus_per_node > 1:
            # Init distributed
            # TODO: multi-node multi-gpu training
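            # NOTE: with the default "env://" init method this also relies on
            # MASTER_ADDR / MASTER_PORT being set in the environment;
            # world_size = gpus_per_node * num_nodes (hard-coded to one node)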
            torch.distributed.init_process_group(
                backend="nccl",
                rank=self.rank,
                world_size=self.config.training.num_gpus_per_node * 1)

        # keep a reference to the raw model; DDP wrapping happens later
        # in configure_ddp
        self.model = model

        self.train_dataloader = train_dataloader_fn(config)
        self.validation_dataloader: Iterable = valid_dataloader_fn(
            config) if valid_dataloader_fn else None
        self.test_dataloader = test_dataloader_fn(
            config) if test_dataloader_fn else None

        self.callback_handler = CallbackHandler(
            config,
            trainer=self,
            callbacks=[],
            verbose=config.training.logging.level == "DEBUG")

        # constants
        self.gradient_accumulation_steps = config.training.optimization.gradient_accumulation_steps
        self.validation_steps_interval = config.training.validation.steps_interval
        self.fp16 = config.training.optimization.fp16
        self.fp16_opt_level = config.training.optimization.fp16_opt_level
        self.distributed_training = False

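        # One optimizer update consumes `gradient_accumulation_steps` batches,
        # so the raw step budget is update steps times the accumulation factor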
        self.total_num_update_steps = int(
            config.training.total_num.update_steps)
        self.total_num_steps = self.total_num_update_steps * int(
            self.gradient_accumulation_steps)
        self.total_num_epochs = int(self.config.training.total_num.epochs)

        # Train in epochs or steps
        if self.total_num_epochs > 0:
            self.training_in_epoch = True
        else:
            if self.total_num_update_steps <= 0:
                raise ValueError(
                    "config.training.total_num.update_steps must be larger than 0"
                )
            self.training_in_epoch = False
            self.total_num_epochs = 1

        # Number of training batches
        if self.training_in_epoch:
            try:
                self.epoch_num_training_steps = len(self.train_dataloader)
                self.total_num_training_steps = self.epoch_num_training_steps * self.total_num_epochs
                self.total_num_update_steps = self.total_num_training_steps // self.gradient_accumulation_steps
            except TypeError:
                # cannot infer the number of steps per epoch because the
                # train dataloader has no length (e.g. an iterable dataset)
                logger.error("Cannot get the length of the train dataloader")
                raise ValueError(
                    "Please specify `total_num_epochs` or `total_num_update_steps`!"
                )
        else:
            self.epoch_num_training_steps = self.total_num_update_steps

        # Validation steps interval
        self.validation_after_num_steps = config.training.validation.after_num_steps
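        # A negative interval means "validate roughly once per epoch"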
        if self.validation_steps_interval < 0:
            self.validation_steps_interval = self.epoch_num_training_steps - 1

        # local variables
        self.global_step_count = 0
        self.epochs_trained = 0
        self.local_step_count = 0

        # set cuda device
        if config.training.num_gpus_per_node > 1:
            torch.cuda.set_device(self.local_rank)
            self.device = torch.device("cuda", self.local_rank)
        elif config.training.num_gpus_per_node == 1:
            self.device = torch.device("cuda")
        else:
            self.device = torch.device("cpu")

        # Configure optimizers and LR schedulers (delegated to the model)
        self.optimizers, self.schedulers = self.configure_optimizers()

        # Model is sent to GPU or CPU
        self.model = move_to_device(self.model, self.device)

        # Mixed-Precision
        if self.fp16 and self.config.training.num_gpus_per_node > 0:
            self.configure_fp16()

        # Distributed Training
        if self.config.training.num_gpus_per_node > 1:
            self.configure_ddp()

        self.configure_callbacks()

        self.log_keys = set()
        self.tmp_vars = {}
        self.callback_handler.fire_event(Events.INITIALIZE)

        # make sure the model has access to trainer info
        self.model.set_trainer(self)

    def update_log_keys(self, keys: List[str]):
        self.log_keys.update(keys)

    def configure_optimizers(self):
        return self.model.configure_optimizers(self.total_num_update_steps)

    def configure_callbacks(self):
        # Callback
        # by default set up LogHandler and Checkpointer
        self.checkpoint_callback = Checkpoint(self.config)
        self.add_callback(self.checkpoint_callback)

        if self.rank == 0:
            self.log_callback = LogHandler(self.config)
            self.add_callback(self.log_callback)

        # Gradient clipping is no longer handled here
        # if self.config.training.optimization.max_gradient_norm > 0:
        #     gradient_clip_norm_callback = GradientClipNorm(self.config)
        #     self.add_callback(gradient_clip_norm_callback)

    def configure_fp16(self):
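        # apex amp patches the model and optimizers in place; `opt_level` is
        # one of "O0"-"O3" ("O1"/"O2" are the common mixed-precision settings)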
        self.model, self.optimizers = amp.initialize(
            self.model, self.optimizers, opt_level=self.fp16_opt_level)

    def configure_ddp(self):
        # Distributed training (should be after apex fp16 initialization)
        self.distributed_training = True
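        # apex's DistributedDataParallel: `delay_allreduce=True` defers all
        # gradient all-reduces to the end of the backward pass (more robust
        # with shared parameters, at the cost of less comm/compute overlap)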
        self.model = DistributedDataParallel(self.model, delay_allreduce=True)
        # trainer.model = torch.nn.parallel.DistributedDataParallel(
        #     trainer.model, device_ids=[trainer.rank], output_device=trainer.rank, find_unused_parameters=True
        # )

    def train(self):
        # Training begins
        self.callback_handler.fire_event(Events.TRAIN_BEGIN)

        # Run validation once before training begins
        if self.rank == 0:
            if self.validation_dataloader is not None:
                self.model.eval()
                self.model.is_training = False
                self.callback_handler.fire_event(Events.VALIDATE_BEGIN)

                self.tmp_vars["validate_metrics"] = self.validate()

                self.callback_handler.fire_event(Events.VALIDATE_END)
                self.model.train()
                self.model.is_training = True

        while True:
            self.callback_handler.fire_event(Events.EPOCH_BEGIN)
            self.train_epoch()
            self.callback_handler.fire_event(Events.EPOCH_END)
            self.epochs_trained += 1

            if self.training_in_epoch:
                if self.epochs_trained >= self.total_num_epochs:
                    break
            elif self.global_step_count >= self.total_num_steps:
                break

        # Training ends
        self.callback_handler.fire_event(Events.TRAIN_END)

        # Only rank 0 can run the test dataset
        if self.rank == 0:
            if self.test_dataloader:
                # TODO: Implement test_dataloader
                raise NotImplementedError

    def train_epoch(self):
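        # Only the first optimizer/scheduler pair is used by this loop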
        self.optimizer = self.optimizers[0]
        self.scheduler = self.schedulers[0]

        self.local_step_count = 0

        for batch in self.train_dataloader:
            self.callback_handler.fire_event(Events.BATCH_BEGIN)

            batch = move_to_device(batch, self.device)
            self.tmp_vars["log_dict"] = self.train_step(batch)

            # Update the model only every `gradient_accumulation_steps`
            # batches; `train_step` has already divided the loss accordingly
            if (self.global_step_count + 1) % self.gradient_accumulation_steps == 0:
                self.step_update()

            self.callback_handler.fire_event(Events.BATCH_END)

            # Only rank 0 can run the validation dataset
            if self.rank == 0:
                if (self.global_step_count > self.validation_after_num_steps
                        and (self.global_step_count + 1) %
                        self.validation_steps_interval == 0):

                    if self.validation_dataloader is not None:
                        self.model.eval()
                        self.model.is_training = False
                        self.callback_handler.fire_event(Events.VALIDATE_BEGIN)

                        self.tmp_vars["validate_metrics"] = self.validate()

                        self.callback_handler.fire_event(Events.VALIDATE_END)
                        self.model.train()
                        self.model.is_training = True

            if self.config.training.num_gpus_per_node > 1:
                torch.distributed.barrier()
            if self.global_step_count >= self.total_num_steps:
                break

            self.global_step_count += 1
            self.local_step_count += 1

    def step_update(self):
        self.callback_handler.fire_event(Events.STEP_BEGIN)
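        # Optimizer first, then LR scheduler (PyTorch >= 1.1 calling order)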
        self.optimizer.step()
        self.scheduler.step()
        self.optimizer.zero_grad()
        self.callback_handler.fire_event(Events.STEP_END)

    def train_step(self, batch):
        self.optimizer = self.optimizers[0]
        results = self.model(batch)
        loss = results["loss"]

        if self.gradient_accumulation_steps > 1:
            loss = loss / self.gradient_accumulation_steps

        self.callback_handler.fire_event(Events.BACKWARD_BEGIN)
        self.loss_backward(loss)
        self.callback_handler.fire_event(Events.BACKWARD_END)

        # Log the un-scaled loss so reported values are comparable across
        # gradient accumulation settings
        log_dict = {"loss": loss.item() * self.gradient_accumulation_steps}
        log_dict["_lr"] = get_lr(self.optimizer)

        for key in self.log_keys:
            log_dict[key] = get_log_variable(results[key])

        return log_dict

    def loss_backward(self, loss):
        # Loss backward
        if self.fp16:
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

    def validate(self):
        # Validation
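        # `predict` is assumed to accumulate metrics inside the FlyModel;
        # `get_metrics(reset=True)` collects and clears them afterwards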
        self.model.eval()
        # No gradient is needed for validation
        with torch.no_grad():
            for batch in self.validation_dataloader:
                # send to cuda device
                batch = move_to_device(batch, self.device)

                if self.distributed_training:
                    self.model.module.predict(batch)
                else:
                    self.model.predict(batch)
        # get metrics
        if self.distributed_training:
            metrics = self.model.module.get_metrics(reset=True)
        else:
            metrics = self.model.get_metrics(reset=True)
        return metrics

    def set_model_state(self, model_state_dict):
        if self.distributed_training:
            self.model.module.load_state_dict(model_state_dict)
        else:
            self.model.load_state_dict(model_state_dict)

    def get_model_state(self):
        if self.distributed_training:
            return self.model.module.state_dict()
        else:
            return self.model.state_dict()

    def set_trainer_state(self, trainer_state_dict):
        self.epochs_trained = trainer_state_dict["epochs_trained"]
        self.global_step_count = trainer_state_dict["global_step_count"]
        self.local_step_count = trainer_state_dict["local_step_count"]

        # Resume the training state
        if self.config.training.resume.resume:
            # AMP State
            if self.config.training.optimization.fp16:
                amp.load_state_dict(trainer_state_dict["amp_state_dict"])

            # Scheduler States
            if self.config.training.resume.resume_schedulers:
                for idx, scheduler in enumerate(self.schedulers):
                    try:
                        scheduler.load_state_dict(
                            trainer_state_dict["schedulers_state_dict"][idx])
                    except Exception:
                        if self.rank == 0:
                            logger.warning(
                                f"Cannot load scheduler {idx}'s state!")

            # Optimizer States - We cannot load optimizers here because of an amp error
            # if self.config.training.resume.resume_optimizers:
            #     for idx, optimizer in enumerate(self.optimizers):
            #         try:
            #             optimizer.load_state_dict(trainer_state_dict["optimizers_state_dict"][idx])
            #         except:
            #             if self.rank == 0:
            #                 logger.warning(f"Cannot Load Optimizer {idx}'s State!")

            # Random States
            if self.config.training.resume.resume_rng_state:
                torch.set_rng_state(trainer_state_dict["cpu_rng_state"])
                trainer_state_dict["cuda_rng_state"] = trainer_state_dict[
                    "cuda_rng_state"][:torch.cuda.device_count()]
                torch.cuda.set_rng_state_all(
                    trainer_state_dict["cuda_rng_state"])

            # All Callbacks
            for callback in self.callback_handler.callbacks:
                try:
                    callback.load_state_dict(
                        trainer_state_dict[str(type(callback))])
                except Exception:
                    logger.error(
                        f"No saved state found for {type(callback)}!")

    def get_trainer_state(self):
        trainer_state_dict = {
            # +1 so that a resumed run starts from the next epoch
            "epochs_trained": self.epochs_trained + 1,
            "global_step_count": self.global_step_count,
            "local_step_count": self.local_step_count,
            "optimizers_state_dict":
            [optimizer.state_dict() for optimizer in self.optimizers],
            "schedulers_state_dict":
            [scheduler.state_dict() for scheduler in self.schedulers],
            "cpu_rng_state": torch.get_rng_state(),
            "cuda_rng_state": torch.cuda.get_rng_state_all(),
        }
        # save amp states
        if self.config.training.optimization.fp16:
            trainer_state_dict["amp_state_dict"] = amp.state_dict()

        # All Callbacks
        for callback in self.callback_handler.callbacks:
            trainer_state_dict[str(type(callback))] = callback.state_dict()

        return trainer_state_dict

    def add_callback(self, callback: Callback):
        self.callback_handler.add_callback(callback)
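

# Minimal usage sketch (hypothetical): `load_config`, `MyFlyModel`, and
# `build_train_dataloader` are placeholder names for illustration only;
# only `TrainerLoop` itself is defined above.
#
#   config = load_config("config.yaml")      # a FlyConfig / DictConfig
#   model = MyFlyModel(config)               # any FlyModel subclass
#   trainer = TrainerLoop(config, model,
#                         train_dataloader_fn=build_train_dataloader)
#   trainer.train()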