    def to_ort_model(self, model, config, args):
        model_desc = self.gpt2_model_description(config.n_head,
                                                 config.vocab_size,
                                                 config.n_embd, config.n_layer,
                                                 config.n_ctx,
                                                 args.per_gpu_train_batch_size)
        learning_rate_description = self.ort_trainer_learning_rate_description()

        def map_optimizer_attributes(name):
            no_decay_keys = ["bias", "gamma", "beta", "LayerNorm"]
            no_decay = False
            for no_decay_key in no_decay_keys:
                if no_decay_key in name:
                    no_decay = True
                    break
            if no_decay:
                return {
                    "alpha": 0.9,
                    "beta": 0.999,
                    "lambda": 0.0,
                    "epsilon": args.adam_epsilon
                }
            else:
                return {
                    "alpha": 0.9,
                    "beta": 0.999,
                    "lambda": args.weight_decay,
                    "epsilon": args.adam_epsilon
                }

        from onnxruntime.capi._pybind_state import set_cuda_device_id, set_arena_extend_strategy, ArenaExtendStrategy
        set_arena_extend_strategy(ArenaExtendStrategy.kSameAsRequested)
        set_cuda_device_id(self.args.local_rank)

        model = ORTTrainer(
            model,
            None,
            model_desc,
            "AdamOptimizer",
            map_optimizer_attributes,
            learning_rate_description,
            args.device,
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            world_rank=self.args.world_rank,
            world_size=self.args.world_size,
            use_mixed_precision=self.args.fp16,
            allreduce_post_accumulation=True,
            _opset_version=12)

        logger.info("****************************Model converted to ORT")
        return model
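A minimal usage sketch for the helper above. Here trainer (an instance of the class that defines to_ort_model) and args (carrying per_gpu_train_batch_size, adam_epsilon, weight_decay, fp16, device, gradient_accumulation_steps and the rank fields) are hypothetical; only the transformers classes are real:

# Hypothetical wiring, not part of the original snippet.
from transformers import GPT2Config, GPT2LMHeadModel

config = GPT2Config()            # defaults: n_layer=12, n_head=12, n_embd=768
model = GPT2LMHeadModel(config)  # plain PyTorch module to be wrapped by ORT
ort_model = trainer.to_ort_model(model, config, args)  # legacy ORTTrainer wrapper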
Example 2
def setup_onnxruntime_with_mpi(args):
    # NOTE: the MPI / Azure ML rank-discovery logic below is disabled (wrapped in a
    # string literal) and kept for reference; the active path below uses
    # args.distributed_rank instead.
    '''
    from mpi4py import MPI
    comm = MPI.COMM_WORLD

    has_aml = 'AZ_BATCH_MASTER_NODE' in os.environ.keys() or 'AZ_BATCHAI_MPI_MASTER_NODE' in os.environ.keys()
    if not has_aml:
        print('Detected local run')
        args.local_rank = comm.Get_rank() % torch.cuda.device_count()
        args.world_rank = comm.Get_rank()
        args.world_size = comm.Get_size()

        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        args.n_gpu = 1

    else:
        print('Detected Azure batch run')
        set_environment_variables_for_nccl_backend(get_local_size() == get_global_size(), IB = args.use_ib)
        args.local_rank = get_local_rank()
        args.local_size = get_local_size()
        args.world_rank = get_world_rank()
        args.world_size = get_global_size()

        print('Local rank: {}'.format(args.local_rank))
        print('Local size: {}'.format(args.local_size))
        print('World rank: {}'.format(args.world_rank))
        print('World size: {}'.format(args.world_size))
        print('CUDA device: {}'.format(args.local_rank))

        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        args.n_gpu = 1

        torch.distributed.init_process_group(backend='nccl')
    '''

    #device = torch.device("cuda", get_local_rank())
    device = torch.device("cuda", args.distributed_rank)

    from onnxruntime.capi._pybind_state import set_cuda_device_id
    #set_cuda_device_id(get_local_rank())
    set_cuda_device_id(args.distributed_rank)

    from onnxruntime.capi._pybind_state import set_arena_extend_strategy, ArenaExtendStrategy
    set_arena_extend_strategy(ArenaExtendStrategy.kSameAsRequested)

    return device
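A small sketch of how this setup helper might be driven. The argparse namespace and the distributed_rank value are assumptions (one process per GPU, rank assigned by the launcher); in the disabled block above the rank would come from mpi4py instead:

# Hypothetical driver code.
import argparse

args = argparse.Namespace(distributed_rank=0)
device = setup_onnxruntime_with_mpi(args)  # pins ORT allocations to cuda:0
print(device)                              # -> cuda:0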
Example 3
    def train(self, model_path: Optional[str] = None):
        """
        Main training entry point.

        Args:
            model_path:
                (Optional) Local path to the model if the model to train has been instantiated from a local path.
                If present, we will try reloading the optimizer/scheduler states from there.
        """
        train_dataloader = self.get_train_dataloader()

        if self.args.max_steps > 0:
            t_total = self.args.max_steps
            num_train_epochs = (self.args.max_steps //
                                (len(train_dataloader) //
                                 self.args.gradient_accumulation_steps) + 1)
        else:
            t_total = int(
                len(train_dataloader) //
                self.args.gradient_accumulation_steps *
                self.args.num_train_epochs)
            num_train_epochs = self.args.num_train_epochs

        config = self.model.config
        model_desc = self.gpt2_model_description(
            config.n_head, config.vocab_size, config.n_embd, config.n_layer,
            config.n_ctx, self.args.per_gpu_train_batch_size)

        from onnxruntime.capi._pybind_state import set_arena_extend_strategy, ArenaExtendStrategy
        set_arena_extend_strategy(ArenaExtendStrategy.kSameAsRequested)

        param_optimizer = list(self.model.named_parameters())
        no_decay = ['bias', 'gamma', 'beta', 'LayerNorm']

        optim_config = optim.AdamConfig(
            params=[{
                'params': [n for n, p in param_optimizer
                           if any(nd in n for nd in no_decay)],
                'lambda_coef': 0.0
            }],
            lr=self.args.learning_rate,
            alpha=0.9,
            beta=0.999,
            lambda_coef=self.args.weight_decay,
            epsilon=self.args.adam_epsilon)

        warmup = self.args.warmup_steps / t_total
        lr_scheduler = optim.lr_scheduler.LinearWarmupLRScheduler(
            total_steps=t_total, warmup=warmup)
        loss_scaler = amp.DynamicLossScaler(
            automatic_update=True,
            loss_scale=float(1 << 20),
            up_scale_window=2000,
            min_loss_scale=1.0,
            max_loss_scale=float(1 << 24)) if self.args.fp16 else None

        opts = orttrainer.ORTTrainerOptions({
            'device': {
                'id': str(self.args.device)
            },
            'distributed': {
                'world_rank': self.args.world_rank,
                'world_size': self.args.world_size,
                'local_rank': self.args.local_rank,
                'allreduce_post_accumulation': True
            },
            'mixed_precision': {
                'enabled': self.args.fp16,
                'loss_scaler': loss_scaler
            },
            'batch': {
                'gradient_accumulation_steps': self.args.gradient_accumulation_steps
            },
            'lr_scheduler': lr_scheduler
        })

        self.ort_model = orttrainer.ORTTrainer(self.model,
                                               model_desc,
                                               optim_config,
                                               None,
                                               options=opts)

        logger.info("****************************Model converted to ORT")
        model = self.ort_model

        if self.tb_writer is not None:
            self.tb_writer.add_text("args", self.args.to_json_string())

        # Train!
        if self.is_world_master():
            logger.info("***** Running training *****")
            logger.info("  Num examples = %d", len(train_dataloader.dataset))
            logger.info("  Num Epochs = %d", num_train_epochs)
            logger.info("  Instantaneous batch size per GPU = %d",
                        self.args.per_gpu_train_batch_size)
            logger.info(
                "  Total train batch size (w. parallel, distributed & accumulation) = %d",
                self.args.train_batch_size *
                self.args.gradient_accumulation_steps *
                (self.args.world_size if self.args.local_rank != -1 else 1),
            )
            logger.info("  Gradient Accumulation steps = %d",
                        self.args.gradient_accumulation_steps)
            logger.info("  Total optimization steps = %d", t_total)

        global_step = 0
        epochs_trained = 0
        steps_trained_in_current_epoch = 0
        # Check if continuing training from a checkpoint
        if model_path is not None:
            # set global_step to global_step of last saved checkpoint from model path
            try:
                global_step = int(model_path.split("-")[-1].split("/")[0])
                epochs_trained = global_step // (
                    len(train_dataloader) //
                    self.args.gradient_accumulation_steps)
                steps_trained_in_current_epoch = global_step % (
                    len(train_dataloader) //
                    self.args.gradient_accumulation_steps)

                logger.info(
                    "  Continuing training from checkpoint, will skip to saved global_step"
                )
                logger.info("  Continuing training from epoch %d",
                            epochs_trained)
                logger.info("  Continuing training from global step %d",
                            global_step)
                logger.info(
                    "  Will skip the first %d steps in the first epoch",
                    steps_trained_in_current_epoch)
            except ValueError:
                global_step = 0
                logger.info("  Starting fine-tuning.")

        tr_loss = 0.0
        logging_loss = 0.0
        global_batch_train_start = time.time()

        train_iterator = trange(
            epochs_trained,
            int(num_train_epochs),
            desc="Epoch",
            disable=self.args.local_rank not in [-1, 0],
        )
        for epoch in train_iterator:
            epoch_iterator = tqdm(train_dataloader,
                                  desc="Iteration",
                                  disable=self.args.local_rank not in [-1, 0])
            for step, inputs in enumerate(epoch_iterator):

                # Skip past any already trained steps if resuming training
                if steps_trained_in_current_epoch > 0:
                    steps_trained_in_current_epoch -= 1
                    continue

                if len(inputs['input_ids']) < self.args.per_gpu_train_batch_size:
                    # skip incomplete batch
                    logger.info('Skipping incomplete batch...')
                    continue

                tr_loss += self._training_step(model, inputs)
                if (step + 1) % self.args.gradient_accumulation_steps == 0 or (
                        # last step in epoch but step is always smaller than gradient_accumulation_steps
                        len(epoch_iterator) <= self.args.gradient_accumulation_steps
                        and (step + 1) == len(epoch_iterator)):

                    global_step += 1
                    global_batch_train_duration = time.time() - global_batch_train_start
                    global_batch_train_start = time.time()

                    if self.args.local_rank in [-1, 0]:
                        if (self.args.logging_steps > 0
                                and global_step % self.args.logging_steps == 0) or (
                                    global_step == 1 and self.args.logging_first_step):
                            logs = {}
                            loss_avg = (tr_loss - logging_loss) / (
                                self.args.logging_steps *
                                self.args.gradient_accumulation_steps)
                            logs["learning_rate"] = lr_scheduler.get_last_lr(
                            )[0]
                            logs["loss"] = loss_avg.item()
                            logs["global_step"] = global_step
                            logs[
                                "global_step_time"] = global_batch_train_duration
                            logging_loss = tr_loss.clone()

                            if self.tb_writer:
                                for k, v in logs.items():
                                    self.tb_writer.add_scalar(k, v, global_step)
                                    run.log(k, v)
                            epoch_iterator.write(
                                json.dumps({**logs, "step": global_step}))

                        if self.args.save_steps > 0 and global_step % self.args.save_steps == 0:
                            # In all cases (even distributed/parallel), self.model is always a reference
                            # to the model we want to save.
                            if hasattr(model, "module"):
                                assert model.module is self.ort_model
                            else:
                                assert model is self.ort_model
                            # Save model checkpoint
                            output_dir = os.path.join(
                                self.args.output_dir,
                                f"{PREFIX_CHECKPOINT_DIR}-{global_step}")
                            self.save_model(output_dir)
                            self._rotate_checkpoints()

                if self.args.max_steps > 0 and global_step > self.args.max_steps:
                    epoch_iterator.close()
                    break
            if self.args.max_steps > 0 and global_step > self.args.max_steps:
                train_iterator.close()
                break

        if self.tb_writer:
            self.tb_writer.close()
        self.update_torch_model()

        logger.info(
            "\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n"
        )
        return TrainOutput(global_step, tr_loss / global_step)
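The loop above calls self._training_step, which is not part of this snippet. A hedged sketch of what it might look like, assuming a model_desc whose inputs are input_ids and labels and whose only declared output is the loss:

    def _training_step(self, model, inputs):
        # Assumed helper (not shown in the original snippet): move the batch to the
        # training device and run one ORT training step.
        inputs = {k: v.to(self.args.device) for k, v in inputs.items()}
        # train_step feeds the inputs in model_desc order; with loss as the single
        # declared output, the returned value is treated as the loss tensor here
        # (an assumption consistent with how the loop accumulates tr_loss).
        loss = model.train_step(inputs["input_ids"], inputs["labels"])
        return loss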