def create_orttrainer_and_load_checkpoint(device, trainer_opts, checkpoint_dir, use_lamb=True):
    """Instantiate and load checkpoint into trainer

    - Instantiates the ORTTrainer with given input trainer_opts configuration for a simple transformer model
    - Loads the checkpoint from directory checkpoint_dir into the trainer
    - Runs eval_step on the trainer so the trainer onnx graph is initialized
    - Returns the trainer state_dict and the pytorch model
    """
    seed = 1
    torch.manual_seed(seed)
    set_seed(seed)

    # PyTorch transformer model setup
    learning_rate = 0.1
    optim_config = optim.LambConfig(lr=learning_rate) if use_lamb else optim.AdamConfig(lr=learning_rate)
    model, model_desc, loss_fn, batcher_fn, train_data, _, _ = _load_pytorch_transformer_model(device)
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=loss_fn,
                                    options=orttrainer.ORTTrainerOptions(trainer_opts))

    # load checkpoint into trainer
    checkpoint_file_name = 'checkpoint*.ortcp'
    checkpoint_files = glob.glob(os.path.join(checkpoint_dir, checkpoint_file_name))
    trainer.load_checkpoint(*checkpoint_files)

    # run an eval step to initialize the graph
    torch.manual_seed(seed)
    set_seed(seed)
    data, targets = batcher_fn(train_data, 0)
    trainer.eval_step(data, targets)

    return trainer.state_dict(), model
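# A minimal usage sketch of the helper above (hedged: the opts dict and the
# checkpoint directory are illustrative, not taken from the test suite; the
# directory is expected to hold 'checkpoint*.ortcp' files written by the
# matching save helper further down).
opts = {'device': {'id': 'cuda'}, 'debug': {'deterministic_compute': True}}
state_dict, pytorch_model = create_orttrainer_and_load_checkpoint('cuda', opts, '/tmp/ort_ckpt')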
def testToyBERTModelMixedPrecisionLossScalerLegacyExperimental(loss_scaler, legacy_loss_scaler):
    # Common setup
    total_steps = 128
    device = "cuda"
    seed = 1

    # EXPERIMENTAL IMPLEMENTATION
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    model_desc = bert_model_description()
    model = load_bert_onnx_model()
    optim_config = optim.AdamConfig(lr=0.001)
    opts = orttrainer.ORTTrainerOptions({
        'debug': {'deterministic_compute': True},
        'device': {'id': device},
        'mixed_precision': {'enabled': True, 'loss_scaler': loss_scaler}
    })
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)
    experimental_losses = []
    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        experimental_losses.append(trainer.train_step(*sample_input).cpu().item())

    # LEGACY IMPLEMENTATION
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    device = torch.device(device)
    model = load_bert_onnx_model()
    legacy_model_desc, learning_rate_description, learning_rate = legacy_model_params(optim_config.lr)
    legacy_trainer = Legacy_ORTTrainer(model, None, legacy_model_desc, "AdamOptimizer",
                                       None, learning_rate_description, device,
                                       _use_deterministic_compute=True,
                                       use_mixed_precision=True,
                                       loss_scaler=legacy_loss_scaler)
    legacy_losses = []
    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        leg_loss = legacy_trainer.train_step(*sample_input, learning_rate)
        legacy_losses.append(leg_loss.cpu().item())

    # Check results
    _test_helpers.assert_model_outputs(experimental_losses, legacy_losses)
def create_orttrainer_and_save_checkpoint(device, trainer_opts, checkpoint_dir,
                                          state_dict_key_name='state_dict', use_lamb=True):
    learning_rate = 0.1
    seed = 1

    torch.manual_seed(seed)
    set_seed(seed)

    optim_config = optim.LambConfig(lr=learning_rate) if use_lamb else optim.AdamConfig(lr=learning_rate)
    model, model_desc, loss_fn, batcher_fn, train_data, _, _ = _load_pytorch_transformer_model(device)
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=loss_fn,
                                    options=orttrainer.ORTTrainerOptions(trainer_opts))

    if 'distributed' in trainer_opts:
        train_data = next(islice(_chunkify(train_data, trainer_opts['distributed']['world_size']),
                                 trainer_opts['distributed']['world_rank'], None))

    # run train steps
    _train(trainer, train_data, batcher_fn)

    # save current model parameters as a checkpoint
    if checkpoint_dir:
        _save(trainer, checkpoint_dir, state_dict_key_name)
def load_model_optim_state_and_eval(device, trainer_opts, use_lamb=True, seed=1, learning_rate=0.1):
    torch.manual_seed(seed)
    set_seed(seed)

    optim_config = optim.LambConfig(lr=learning_rate) if use_lamb else optim.AdamConfig(lr=learning_rate)
    model, model_desc, loss_fn, batcher_fn, train_data, _, _ = _load_pytorch_transformer_model(device)
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=loss_fn,
                                    options=orttrainer.ORTTrainerOptions(trainer_opts))

    # load dummy state
    dummy_init_state = generate_dummy_optim_state(model, optim_config)
    trainer.load_state_dict(dummy_init_state)

    # run an eval step to initialize the graph
    data, targets = batcher_fn(train_data, 0)
    trainer.eval_step(data, targets)

    optimizer_state_dict = trainer.state_dict()
    del optimizer_state_dict["model"]

    return dummy_init_state, optimizer_state_dict
def testToyBERTModelGradientAccumulationLegacyExperimental(gradient_accumulation_steps):
    # Common setup
    total_steps = 128
    device = "cuda"
    seed = 1

    # EXPERIMENTAL IMPLEMENTATION
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    model_desc = bert_model_description()
    model = load_bert_onnx_model()
    optim_config = optim.AdamConfig()
    opts = orttrainer.ORTTrainerOptions({
        'debug': {'deterministic_compute': True},
        'device': {'id': device},
        'batch': {'gradient_accumulation_steps': gradient_accumulation_steps},
    })
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)
    experimental_losses = []
    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        loss = trainer.train_step(*sample_input)
        experimental_losses.append(loss.cpu().item())

    # LEGACY IMPLEMENTATION
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    device = torch.device(device)
    model = load_bert_onnx_model()
    legacy_model_desc, learning_rate_description, learning_rate = legacy_model_params(optim_config.lr)
    legacy_trainer = Legacy_ORTTrainer(model, None, legacy_model_desc, "AdamOptimizer",
                                       None, learning_rate_description, device,
                                       _use_deterministic_compute=True,
                                       gradient_accumulation_steps=gradient_accumulation_steps)
    legacy_losses = []
    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        leg_loss = legacy_trainer.train_step(*sample_input, learning_rate)
        legacy_losses.append(leg_loss.cpu().item())

    # Check results
    _test_helpers.assert_model_outputs(experimental_losses, legacy_losses)
def create_orttrainer_and_save_checkpoint_bart(device, trainer_opts, checkpoint_dir,
                                               state_dict_key_name="state_dict",
                                               use_lamb=True, seed=1, learning_rate=0.1):
    """Instantiate trainer and save checkpoint for BART.

    - Instantiates the ORTTrainer with given input trainer_opts configuration for a simple BART model
    - Loads a dummy optimizer state into the trainer
    - Runs eval_step on the trainer so the trainer onnx graph is initialized
    - Saves the trainer state as a checkpoint in checkpoint_dir; for model-parallel runs,
      rank 0 also pickles the expected complete state dict for later comparison
    """
    torch.manual_seed(seed)
    set_seed(seed)

    ort_trainer_opts = orttrainer.ORTTrainerOptions(trainer_opts)
    optim_config = optim.LambConfig(lr=learning_rate) if use_lamb else optim.AdamConfig(lr=learning_rate)
    model, model_desc = _load_bart_model()
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=ort_trainer_opts)

    # load dummy optimizer state as we are not going to run real training
    dummy_init_state = generate_dummy_optim_state(model, optim_config)
    init_state = copy.deepcopy(dummy_init_state)
    trainer.load_state_dict(dummy_init_state)

    # run an eval step to initialize the graph
    src_tokens, prev_output_tokens, target = generate_random_input_from_bart_model_desc(model_desc, seed=seed)
    trainer.eval_step(src_tokens, prev_output_tokens, target)

    # save current model parameters as a checkpoint
    if checkpoint_dir:
        if _is_model_parallel_run(ort_trainer_opts):
            _save(trainer, checkpoint_dir, state_dict_key_name,
                  world_rank=ort_trainer_opts.distributed.world_rank)
            # save the initial complete model and optimizer states
            if ort_trainer_opts.distributed.world_rank == 0:
                init_state["model"] = {"full_precision": dict()}
                for initializer in model.graph.initializer:
                    init_state["model"]["full_precision"][initializer.name] = numpy_helper.to_array(initializer)
                with open(os.path.join(checkpoint_dir, "expected_state_dict.pkl"), "wb") as f:
                    pickle.dump(init_state, f)
        else:
            _save(trainer, checkpoint_dir, state_dict_key_name)
def testToyBERTModelLegacyExperimentalCustomOptimParameters(params, legacy_optim_map):
    # Common setup
    total_steps = 128
    device = "cuda"
    seed = 1

    # EXPERIMENTAL API
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    model_desc = bert_model_description()
    model = load_bert_onnx_model()
    optim_config = optim.AdamConfig(params, alpha=0.9, beta=0.999, lambda_coef=0.01,
                                    epsilon=1e-6, do_bias_correction=False)
    opts = orttrainer.ORTTrainerOptions({
        "debug": {"deterministic_compute": True},
        "device": {"id": device},
    })
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)
    experimental_losses = []
    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        experimental_losses.append(trainer.train_step(*sample_input).cpu().item())

    # LEGACY IMPLEMENTATION
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    device = torch.device(device)
    model = load_bert_onnx_model()
    legacy_model_desc, learning_rate_description, learning_rate = legacy_model_params(trainer.optim_config.lr)
    legacy_trainer = Legacy_ORTTrainer(model, None, legacy_model_desc, "AdamOptimizer",
                                       legacy_optim_map, learning_rate_description, device,
                                       _use_deterministic_compute=True)
    legacy_losses = []
    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        legacy_sample_input = [*sample_input, learning_rate]
        legacy_losses.append(legacy_trainer.train_step(legacy_sample_input).cpu().item())

    # Check results
    _test_helpers.assert_model_outputs(experimental_losses, legacy_losses)
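# For reference, a hedged sketch of the shape the `params` argument takes in
# optim.AdamConfig, mirroring the grouping used in the prepare_model helpers
# below: a list of groups, each naming parameters (by graph initializer name)
# together with per-group hyperparameter overrides. The parameter names here
# are illustrative, not taken from the real BERT graph.
example_params = [
    {'params': ['bert.embeddings.word_embeddings.weight'], 'lambda_coef': 0.0},
    {'params': ['bert.encoder.layer.0.attention.self.query.bias'], 'epsilon': 1e-8},
]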
def prepare_model(args, device):
    config = BertConfig.from_pretrained(args.bert_model, cache_dir=args.cache_dir)

    # config.num_hidden_layers = 12
    if args.force_num_hidden_layers:
        logger.info("Modifying model config with num_hidden_layers to %d", args.force_num_hidden_layers)
        config.num_hidden_layers = args.force_num_hidden_layers

    model = BertForPreTraining(config)
    if args.init_state_dict is not None:
        model.load_state_dict(args.init_state_dict)
    model_desc = bert_model_description(config)

    lr_scheduler = LinearWarmupLRScheduler(total_steps=int(args.max_steps), warmup=args.warmup_proportion)

    loss_scaler = amp.DynamicLossScaler() if args.fp16 else None

    options = orttrainer.ORTTrainerOptions({
        'batch': {'gradient_accumulation_steps': args.gradient_accumulation_steps},
        'device': {'id': str(device)},
        'mixed_precision': {'enabled': args.fp16, 'loss_scaler': loss_scaler},
        'graph_transformer': {
            'attn_dropout_recompute': args.attn_dropout_recompute,
            'gelu_recompute': args.gelu_recompute,
            'transformer_layer_recompute': args.transformer_layer_recompute,
        },
        'debug': {'deterministic_compute': True},
        'utils': {'grad_norm_clip': True},
        'distributed': {
            'world_rank': max(0, args.local_rank),
            'world_size': args.world_size,
            'local_rank': max(0, args.local_rank),
            'allreduce_post_accumulation': args.allreduce_post_accumulation,
            'deepspeed_zero_optimization': {'stage': args.deepspeed_zero_stage},
            'enable_adasum': False},
        'lr_scheduler': lr_scheduler
    })

    param_optimizer = list(model.named_parameters())
    no_decay_keys = ["bias", "gamma", "beta", "LayerNorm"]
    params = [{
        'params': [n for n, p in param_optimizer
                   if any(no_decay_key in n for no_decay_key in no_decay_keys)],
        "alpha": 0.9, "beta": 0.999, "lambda": 0.0, "epsilon": 1e-6
    }, {
        'params': [n for n, p in param_optimizer
                   if not any(no_decay_key in n for no_decay_key in no_decay_keys)],
        "alpha": 0.9, "beta": 0.999, "lambda": 0.0, "epsilon": 1e-6
    }]

    optim_config = optim.AdamConfig(params=params, lr=2e-5, do_bias_correction=True)

    model = orttrainer.ORTTrainer(model, model_desc, optim_config, options=options)

    return model
def testToyBERTModelLRScheduler(initial_lr, lr_scheduler, expected_learning_rates, expected_losses):
    return  # TODO: re-enable after nondeterminism on backend is fixed

    # Common setup
    device = "cuda"
    total_steps = 10
    seed = 1
    warmup = 0.05
    cycles = 0.5
    power = 1.0
    lr_end = 1e-7
    rtol = 1e-3
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)

    # Setup LR Schedulers
    if (lr_scheduler == optim.lr_scheduler.ConstantWarmupLRScheduler
            or lr_scheduler == optim.lr_scheduler.LinearWarmupLRScheduler):
        lr_scheduler = lr_scheduler(total_steps=total_steps, warmup=warmup)
    elif lr_scheduler == optim.lr_scheduler.CosineWarmupLRScheduler:
        lr_scheduler = lr_scheduler(total_steps=total_steps, warmup=warmup, cycles=cycles)
    elif lr_scheduler == optim.lr_scheduler.PolyWarmupLRScheduler:
        lr_scheduler = lr_scheduler(total_steps=total_steps, warmup=warmup, power=power, lr_end=lr_end)
    else:
        raise RuntimeError("Invalid lr_scheduler")

    # Modeling
    model_desc = bert_model_description()
    model = load_bert_onnx_model()
    optim_config = optim.AdamConfig(lr=initial_lr)
    opts = orttrainer.ORTTrainerOptions({
        "debug": {"deterministic_compute": True},
        "device": {"id": device},
        "lr_scheduler": lr_scheduler,
    })
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)

    # Train
    losses = []
    learning_rates = []
    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        losses.append(trainer.train_step(*sample_input).cpu().item())
        learning_rates.append(trainer.options.lr_scheduler.get_last_lr()[0])

    # Check output
    _test_helpers.assert_model_outputs(learning_rates, expected_learning_rates, rtol=rtol)
    _test_helpers.assert_model_outputs(losses, expected_losses, rtol=rtol)
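# For intuition, a hedged sketch (not the ORT source) of the schedule shape a
# LinearWarmupLRScheduler typically produces: a linear ramp over the warmup
# fraction of total_steps, then a linear decay toward zero.
def linear_warmup_lr(step, total_steps, warmup, initial_lr):
    warmup_steps = warmup * total_steps
    if step < warmup_steps:
        return initial_lr * step / warmup_steps  # ramp up
    # decay linearly from initial_lr down to 0 over the remaining steps
    return initial_lr * max(0.0, (total_steps - step) / (total_steps - warmup_steps))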
def create_initialized_orttrainer(device, trainer_opts, use_lamb=True):
    seed = 1
    torch.manual_seed(seed)
    set_seed(seed)

    learning_rate = 1e-10
    optim_config = optim.LambConfig(lr=learning_rate) if use_lamb else optim.AdamConfig(lr=learning_rate)
    model, model_desc, loss_fn, batcher_fn, train_data, _, _ = _load_pytorch_transformer_model(device)
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=loss_fn,
                                    options=orttrainer.ORTTrainerOptions(trainer_opts))

    _train(trainer, train_data, batcher_fn)

    return trainer
def verify_optimizer_state_match(device, opts, checkpoint_dir, world_rank, use_lamb=False):
    expected_optim_state, trainer_state = load_model_optim_state_and_eval(device, opts, use_lamb)
    trainer_state = split_state_dict(trainer_state)

    # Roundabout way of checking optimizer states: save each rank's state dict
    # into a temporary folder, then read and aggregate them on rank 0.
    with open(os.path.join(checkpoint_dir, 'distributed_state_' + str(world_rank) + '.pkl'), "wb") as f:
        pickle.dump(trainer_state, f)
    dist.barrier()

    if world_rank == 0:
        num_states = len(glob.glob1(checkpoint_dir, "distributed_state*"))
        optimizer_states = dict()
        for rank in range(num_states):
            rank_state_dict = None
            with open(os.path.join(checkpoint_dir, 'distributed_state_' + str(rank) + '.pkl'), 'rb') as f:
                rank_state_dict = pickle.load(f)
            # collect optimizer states for later comparison since they are sharded
            aggregate_states(optimizer_states, rank_state_dict['optimizer'])

        # compare optimizer states
        optimizer_config = optim.LambConfig() if use_lamb else optim.AdamConfig()
        actual_optim_state = get_optim_state_from_state_dict(optimizer_states, optimizer_config)
        assert actual_optim_state.keys() == expected_optim_state.keys()
        for param_name, a_state in actual_optim_state.items():
            for k, v in a_state.items():
                assert_allclose(v.reshape(expected_optim_state[param_name][k].shape),
                                expected_optim_state[param_name][k],
                                err_msg=f"Optimizer state mismatch for param {param_name}, key {k}")

    dist.barrier()
    os.remove(os.path.join(checkpoint_dir, 'distributed_state_' + str(world_rank) + '.pkl'))
def create_orttrainer_and_load_checkpoint_bart(device, trainer_opts, checkpoint_dir,
                                               use_lamb=True, seed=1, learning_rate=0.1):
    """Instantiate and load checkpoint into trainer

    - Instantiates the ORTTrainer with given input trainer_opts configuration for a simple BART model
    - Loads the checkpoint from directory checkpoint_dir into the trainer
    - Runs eval_step on the trainer so the trainer onnx graph is initialized
    - Returns the trainer state_dict, the expected state dict if present, and the onnx model
    """
    torch.manual_seed(seed)
    set_seed(seed)

    # model setup
    optim_config = optim.LambConfig(lr=learning_rate) if use_lamb else optim.AdamConfig(lr=learning_rate)
    model, model_desc = _load_bart_model()
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config,
                                    options=orttrainer.ORTTrainerOptions(trainer_opts))

    # load checkpoint into trainer
    checkpoint_file_name = "checkpoint*.ortcp"
    checkpoint_files = glob.glob(os.path.join(checkpoint_dir, checkpoint_file_name))
    trainer.load_checkpoint(*checkpoint_files)

    # run an eval step to initialize the graph
    src_tokens, prev_output_tokens, target = generate_random_input_from_bart_model_desc(model_desc, seed=seed)
    trainer.eval_step(src_tokens, prev_output_tokens, target)

    expected_state_dict = None
    fname = os.path.join(checkpoint_dir, "expected_state_dict.pkl")
    if os.path.isfile(fname):
        with open(fname, "rb") as f:
            expected_state_dict = pickle.load(f)

    return trainer.state_dict(), expected_state_dict, model
def create_orttrainer_and_save_checkpoint(device, trainer_opts, checkpoint_dir,
                                          state_dict_key_name="state_dict",
                                          use_lamb=True, seed=1, learning_rate=0.1):
    torch.manual_seed(seed)
    set_seed(seed)

    ort_trainer_opts = orttrainer.ORTTrainerOptions(trainer_opts)
    optim_config = optim.LambConfig(lr=learning_rate) if use_lamb else optim.AdamConfig(lr=learning_rate)
    model, model_desc, loss_fn, batcher_fn, train_data, _, _ = _load_pytorch_transformer_model(device)
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=loss_fn,
                                    options=ort_trainer_opts)

    if "distributed" in trainer_opts:
        train_data = next(islice(_chunkify(train_data, trainer_opts["distributed"]["world_size"]),
                                 trainer_opts["distributed"]["world_rank"], None))

    # run train steps
    _train(trainer, train_data, batcher_fn)

    # save current model parameters as a checkpoint
    if checkpoint_dir:
        if _is_model_parallel_run(ort_trainer_opts):
            _save(trainer, checkpoint_dir, state_dict_key_name,
                  world_rank=ort_trainer_opts.distributed.world_rank)
        else:
            _save(trainer, checkpoint_dir, state_dict_key_name)
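# A hedged round-trip sketch pairing the two transformer-model helpers: train
# and save on one trainer, then reload the checkpoint into a fresh trainer.
# The opts dict and directory are illustrative, not taken from the test suite.
opts = {'device': {'id': 'cuda'}, 'debug': {'deterministic_compute': True}}
create_orttrainer_and_save_checkpoint('cuda', opts, '/tmp/ort_ckpt')
loaded_state, _ = create_orttrainer_and_load_checkpoint('cuda', opts, '/tmp/ort_ckpt')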
def train(self, model_path: Optional[str] = None):
    """
    Main training entry point.

    Args:
        model_path: (Optional) Local path to model if model to train has been instantiated
            from a local path. If present, we will try reloading the optimizer/scheduler
            states from there.
    """
    train_dataloader = self.get_train_dataloader()

    if self.args.max_steps > 0:
        t_total = self.args.max_steps
        num_train_epochs = (self.args.max_steps //
                            (len(train_dataloader) // self.args.gradient_accumulation_steps) + 1)
    else:
        t_total = int(len(train_dataloader) // self.args.gradient_accumulation_steps
                      * self.args.num_train_epochs)
        num_train_epochs = self.args.num_train_epochs

    config = self.model.config
    model_desc = self.gpt2_model_description(config.n_head, config.vocab_size, config.n_embd,
                                             config.n_layer, config.n_ctx,
                                             self.args.per_gpu_train_batch_size)

    from onnxruntime.capi._pybind_state import set_arena_extend_strategy, ArenaExtendStrategy
    set_arena_extend_strategy(ArenaExtendStrategy.kSameAsRequested)

    param_optimizer = list(self.model.named_parameters())
    no_decay = ['bias', 'gamma', 'beta', 'LayerNorm']

    optim_config = optim.AdamConfig(
        params=[{
            'params': [n for n, p in param_optimizer if any(nd in n for nd in no_decay)],
            'lambda_coef': 0.0
        }],
        lr=self.args.learning_rate, alpha=0.9, beta=0.999,
        lambda_coef=self.args.weight_decay, epsilon=self.args.adam_epsilon)

    warmup = self.args.warmup_steps / t_total
    lr_scheduler = optim.lr_scheduler.LinearWarmupLRScheduler(total_steps=t_total, warmup=warmup)

    loss_scaler = amp.DynamicLossScaler(automatic_update=True,
                                        loss_scale=float(1 << 20),
                                        up_scale_window=2000,
                                        min_loss_scale=1.0,
                                        max_loss_scale=float(1 << 24)) if self.args.fp16 else None

    opts = orttrainer.ORTTrainerOptions({
        'device': {'id': str(self.args.device)},
        'distributed': {
            'world_rank': self.args.world_rank,
            'world_size': self.args.world_size,
            'local_rank': self.args.local_rank,
            'allreduce_post_accumulation': True},
        'mixed_precision': {'enabled': self.args.fp16, 'loss_scaler': loss_scaler},
        'batch': {'gradient_accumulation_steps': self.args.gradient_accumulation_steps},
        'lr_scheduler': lr_scheduler
    })

    self.ort_model = orttrainer.ORTTrainer(self.model, model_desc, optim_config, None, options=opts)
    logger.info("****************************Model converted to ORT")
    model = self.ort_model

    if self.tb_writer is not None:
        self.tb_writer.add_text("args", self.args.to_json_string())

    # Train!
    if self.is_world_master():
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_dataloader.dataset))
        logger.info("  Num Epochs = %d", num_train_epochs)
        logger.info("  Instantaneous batch size per GPU = %d", self.args.per_gpu_train_batch_size)
        logger.info(
            "  Total train batch size (w. parallel, distributed & accumulation) = %d",
            self.args.train_batch_size * self.args.gradient_accumulation_steps *
            (self.args.world_size if self.args.local_rank != -1 else 1),
        )
parallel, distributed & accumulation) = %d", self.args.train_batch_size * self.args.gradient_accumulation_steps * (self.args.world_size if self.args.local_rank != -1 else 1), ) logger.info(" Gradient Accumulation steps = %d", self.args.gradient_accumulation_steps) logger.info(" Total optimization steps = %d", t_total) global_step = 0 epochs_trained = 0 steps_trained_in_current_epoch = 0 # Check if continuing training from a checkpoint if model_path is not None: # set global_step to global_step of last saved checkpoint from model path try: global_step = int(model_path.split("-")[-1].split("/")[0]) epochs_trained = global_step // ( len(train_dataloader) // self.args.gradient_accumulation_steps) steps_trained_in_current_epoch = global_step % ( len(train_dataloader) // self.args.gradient_accumulation_steps) logger.info( " Continuing training from checkpoint, will skip to saved global_step" ) logger.info(" Continuing training from epoch %d", epochs_trained) logger.info(" Continuing training from global step %d", global_step) logger.info( " Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch) except ValueError: global_step = 0 logger.info(" Starting fine-tuning.") tr_loss = 0.0 logging_loss = 0.0 global_batch_train_start = time.time() train_iterator = trange( epochs_trained, int(num_train_epochs), desc="Epoch", disable=self.args.local_rank not in [-1, 0], ) for epoch in train_iterator: epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=self.args.local_rank not in [-1, 0]) for step, inputs in enumerate(epoch_iterator): # Skip past any already trained steps if resuming training if steps_trained_in_current_epoch > 0: steps_trained_in_current_epoch -= 1 continue if len(inputs['input_ids'] ) < self.args.per_gpu_train_batch_size: # skip incomplete batch logger.info('Skipping incomplete batch...') continue tr_loss += self._training_step(model, inputs) if (step + 1) % self.args.gradient_accumulation_steps == 0 or ( # last step in epoch but step is always smaller than gradient_accumulation_steps len(epoch_iterator) <= self.args.gradient_accumulation_steps and (step + 1) == len(epoch_iterator)): global_step += 1 global_batch_train_duration = time.time( ) - global_batch_train_start global_batch_train_start = time.time() if self.args.local_rank in [-1, 0]: if (self.args.logging_steps > 0 and global_step % self.args.logging_steps == 0) or (global_step == 1 and self.args.logging_first_step): logs = {} loss_avg = (tr_loss - logging_loss) / ( self.args.logging_steps * self.args.gradient_accumulation_steps) logs["learning_rate"] = lr_scheduler.get_last_lr( )[0] logs["loss"] = loss_avg.item() logs["global_step"] = global_step logs[ "global_step_time"] = global_batch_train_duration logging_loss = tr_loss.clone() if self.tb_writer: for k, v in logs.items(): self.tb_writer.add_scalar( k, v, global_step) run.log(k, v) epoch_iterator.write( json.dumps({ **logs, **{ "step": global_step } })) if self.args.save_steps > 0 and global_step % self.args.save_steps == 0: # In all cases (even distributed/parallel), self.model is always a reference # to the model we want to save. 
if hasattr(model, "module"): assert model.module is self.ort_model else: assert model is self.ort_model # Save model checkpoint output_dir = os.path.join( self.args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{global_step}") self.save_model(output_dir) self._rotate_checkpoints() if self.args.max_steps > 0 and global_step > self.args.max_steps: epoch_iterator.close() break if self.args.max_steps > 0 and global_step > self.args.max_steps: train_iterator.close() break if self.tb_writer: self.tb_writer.close() self.update_torch_model() logger.info( "\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n" ) return TrainOutput(global_step, tr_loss / global_step)
def testToyBERTModelLegacyExperimentalLRScheduler(initial_lr, lr_scheduler, legacy_lr_scheduler):
    ############################################################################
    # These tests require hard-coded values for 'total_steps' and 'initial_lr' #
    ############################################################################

    # Common setup
    total_steps = 128
    device = 'cuda'
    seed = 1
    warmup = 0.05
    cycles = 0.5
    power = 1.
    lr_end = 1e-7

    # Setup both Experimental and Legacy LR Schedulers before the experimental loop
    if (legacy_lr_scheduler == _test_commons.legacy_constant_lr_scheduler
            or legacy_lr_scheduler == _test_commons.legacy_linear_lr_scheduler):
        legacy_lr_scheduler = partial(legacy_lr_scheduler, initial_lr=initial_lr,
                                      total_steps=total_steps, warmup=warmup)
    elif legacy_lr_scheduler == _test_commons.legacy_cosine_lr_scheduler:
        legacy_lr_scheduler = partial(legacy_lr_scheduler, initial_lr=initial_lr,
                                      total_steps=total_steps, warmup=warmup, cycles=cycles)
    elif legacy_lr_scheduler == _test_commons.legacy_poly_lr_scheduler:
        legacy_lr_scheduler = partial(legacy_lr_scheduler, initial_lr=initial_lr,
                                      total_steps=total_steps, warmup=warmup, power=power, lr_end=lr_end)
    else:
        raise RuntimeError("Invalid legacy_lr_scheduler")

    if (lr_scheduler == optim.lr_scheduler.ConstantWarmupLRScheduler
            or lr_scheduler == optim.lr_scheduler.LinearWarmupLRScheduler):
        lr_scheduler = lr_scheduler(total_steps=total_steps, warmup=warmup)
    elif lr_scheduler == optim.lr_scheduler.CosineWarmupLRScheduler:
        lr_scheduler = lr_scheduler(total_steps=total_steps, warmup=warmup, cycles=cycles)
    elif lr_scheduler == optim.lr_scheduler.PolyWarmupLRScheduler:
        lr_scheduler = lr_scheduler(total_steps=total_steps, warmup=warmup, power=power, lr_end=lr_end)
    else:
        raise RuntimeError("Invalid lr_scheduler")

    # EXPERIMENTAL API
    model_desc = bert_model_description()
    model = load_bert_onnx_model()
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    optim_config = optim.AdamConfig(lr=initial_lr)
    opts = orttrainer.ORTTrainerOptions({
        'debug': {'deterministic_compute': True},
        'device': {'id': device},
        'lr_scheduler': lr_scheduler
    })
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)
    experimental_losses = []
    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        experimental_losses.append(trainer.train_step(*sample_input).cpu().item())
        assert_allclose(trainer.options.lr_scheduler.get_last_lr()[0], legacy_lr_scheduler(i))

    # LEGACY IMPLEMENTATION
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    device = torch.device(device)
    model = load_bert_onnx_model()
    legacy_model_desc, learning_rate_description, learning_rate = legacy_model_params(initial_lr)
    legacy_trainer = Legacy_ORTTrainer(model, None, legacy_model_desc, "AdamOptimizer",
                                       None, learning_rate_description, device,
                                       _use_deterministic_compute=True,
                                       get_lr_this_step=legacy_lr_scheduler)
    legacy_losses = []
    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, i)
        leg_loss = legacy_trainer.train_step(*sample_input)
        legacy_losses.append(leg_loss.cpu().item())

    # Check results
    _test_helpers.assert_model_outputs(experimental_losses, legacy_losses)
        total_loss += len(data) * loss.item()
    return total_loss / (len(data_source) - 1)

best_val_loss = float("inf")
epochs = 3  # The number of epochs
best_model = None

model_description = {
    'inputs': [('src', ['bptt', 'batch_size']),
               ('label', ['bptt_x_batch_size'])],
    'outputs': [('loss', [], True),
                ('output', ['bptt', 'batch_size', ntokens])]
}

optimizer_config = optim.AdamConfig(lr=learning_rate)

trainer = ORTTrainer(model,                  # model
                     model_description,      # model description
                     optimizer_config,       # optimizer configuration
                     loss_with_flat_output)  # loss function

for epoch in range(1, epochs + 1):
    epoch_start_time = time.time()
    train()
    val_loss = evaluate(model, val_data)
    print('-' * 89)
    print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | '
          'valid ppl {:8.2f}'.format(epoch, (time.time() - epoch_start_time),
                                     val_loss, math.exp(val_loss)))
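# The loss wrapper passed to ORTTrainer above is referenced but not defined in
# this snippet. A hedged sketch of what it plausibly does, given the model
# description: flatten the [bptt, batch_size, ntokens] logits before applying
# cross-entropy against the flattened labels. Signature and body are assumptions.
import torch.nn.functional as F

def loss_with_flat_output(output, label):
    # output: [bptt, batch_size, ntokens] -> [bptt * batch_size, ntokens]
    return F.cross_entropy(output.view(-1, output.size(-1)), label)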
                                    model_desc2, optim_config2, options=opts)
    trainer2.load_state_dict(state_dict)

    # Evaluate once to get a base loss
    ckpt_loss = trainer2.eval_step(*sample_input)

    # Must match as both trainers have the same dict state
    assert_allclose(loss.cpu(), ckpt_loss.cpu())
    loaded_state_dict = trainer2.state_dict()
    _test_commons.assert_all_states_close_ort(state_dict, loaded_state_dict)


@pytest.mark.parametrize("optimizer, mixedprecision_enabled", [
    (optim.LambConfig(), False),
    (optim.AdamConfig(), False),
    (optim.LambConfig(), True),
    (optim.AdamConfig(), True),
])
def testToyBertLoadOptimState(optimizer, mixedprecision_enabled):
    # Common setup
    rtol = 1e-03
    device = 'cuda'
    seed = 1
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    optim_config = optimizer
    opts = orttrainer.ORTTrainerOptions({
        'debug': {'deterministic_compute': True},
def prepare_model(args, device):
    config = BertConfig.from_pretrained('bert-base-uncased', cache_dir=args.cache_dir)

    if args.force_num_hidden_layers:
        logger.info("Modifying model config with num_hidden_layers to %d", args.force_num_hidden_layers)
        config.num_hidden_layers = args.force_num_hidden_layers

    model = BertForPreTraining(config)
    model_desc = bert_model_description(config)

    lr_scheduler = PolyWarmupLRScheduler(total_steps=int(args.max_steps))

    loss_scaler = amp.DynamicLossScaler() if args.fp16 else None

    options = orttrainer.ORTTrainerOptions({
        'batch': {'gradient_accumulation_steps': args.gradient_accumulation_steps},
        'device': {'id': str(device)},
        'mixed_precision': {'enabled': args.fp16, 'loss_scaler': loss_scaler},
        'debug': {'deterministic_compute': True},
        'utils': {'grad_norm_clip': True},
        'distributed': {'allreduce_post_accumulation': True},
        'lr_scheduler': lr_scheduler
    })

    param_optimizer = list(model.named_parameters())
    no_decay_keys = ["bias", "gamma", "beta", "LayerNorm"]
    params = [{
        'params': [n for n, p in param_optimizer
                   if any(no_decay_key in n for no_decay_key in no_decay_keys)],
        "alpha": 0.9, "beta": 0.999, "lambda": 0.0, "epsilon": 1e-6
    }, {
        'params': [n for n, p in param_optimizer
                   if not any(no_decay_key in n for no_decay_key in no_decay_keys)],
        "alpha": 0.9, "beta": 0.999, "lambda": 0.0, "epsilon": 1e-6
    }]

    optim_config = optim.AdamConfig(params=params, lr=2e-5, do_bias_correction=True)

    model = orttrainer.ORTTrainer(model, model_desc, optim_config, options=options)

    return model
num_pipeline_stages = 2

# Compute batch size for micro-batches.
n_slice = int(n / num_pipeline_steps)

cuda_device = 'cuda:' + str(rank)

# Schema used when running the original batch.
schema = {'inputs': [('x', ['n', 'd_in']), ('target', ['n'])],
          'outputs': [('loss', [], True), ('output', ['n', d_out])]}

# Actual schema used when running micro-batches.
pipeline_schema = {'x': [n_slice, d_in], 'target': [n_slice],
                   'output': [n_slice, d_out], 'loss': []}

# Describe which axis to slice along for each sliced tensor.
sliced_axes = {'x': 0, 'target': 0, 'output': 0}

adam_config = optim.AdamConfig(lr=0.1)

# Specify configuration for pipeline parallel training.
trainer_config = ORTTrainerOptions({
    'batch': {
        'gradient_accumulation_steps': num_pipeline_steps
    },
    'device': {
        'id': cuda_device
    },
    'distributed': {
        'world_size': total_ranks,
        'world_rank': rank,
        'data_parallel_size': int(total_ranks / num_pipeline_stages),
        'horizontal_parallel_size': 1,
        'pipeline_parallel': {
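# Illustrative only (not the ORT pipeline API): a hedged sketch of how one
# micro-batch could be cut out of the full batch along the axes declared in
# `sliced_axes` above, producing tensors whose shapes match `pipeline_schema`.
def slice_micro_batch(full_batch, step, num_steps, axes):
    sliced = {}
    for name, tensor in full_batch.items():
        axis = axes.get(name)
        if axis is None:
            sliced[name] = tensor  # unsliced tensors (e.g. scalar loss) pass through
        else:
            size = tensor.shape[axis] // num_steps
            sliced[name] = tensor.narrow(axis, step * size, size)
    return sliced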
def train(self):
    """
    Main training entry point.
    """
    train_dataloader = self.get_train_dataloader()

    if self.args.max_steps > 0:
        t_total = self.args.max_steps
        num_train_epochs = (self.args.max_steps //
                            (len(train_dataloader) // self.args.gradient_accumulation_steps) + 1)
    else:
        t_total = int(len(train_dataloader) // self.args.gradient_accumulation_steps
                      * self.args.num_train_epochs)
        num_train_epochs = self.args.num_train_epochs

    lr_scheduler = orttrainer.optim.LinearWarmupLRScheduler(
        t_total, self.args.warmup_steps / float(t_total))

    loss_scaler = amp.DynamicLossScaler() if self.args.fp16 else None

    device = self.args.device.type
    device = f"{device}:{self.args.device.index}" if self.args.device.index else f"{device}:0"

    options = orttrainer.ORTTrainerOptions({
        "batch": {"gradient_accumulation_steps": self.args.gradient_accumulation_steps},
        "device": {"id": device},
        "mixed_precision": {"enabled": self.args.fp16, "loss_scaler": loss_scaler},
        "debug": {"deterministic_compute": True},
        "utils": {"grad_norm_clip": False},
        "distributed": {
            # we are running single node multi gpu test. thus world_rank = local_rank
            # and world_size = self.args.n_gpu
            "world_rank": max(0, self.args.local_rank),
            "world_size": int(self.world_size),
            "local_rank": max(0, self.args.local_rank),
            "allreduce_post_accumulation": True,
        },
        "lr_scheduler": lr_scheduler,
    })

    param_optimizer = list(self.model.named_parameters())
    params = [
        {
            "params": [n for n, p in param_optimizer
                       if "bias" in n or "LayerNorm.weight" in n],
            "weight_decay_mode": 1,
        },
        {
            "params": [n for n, p in param_optimizer
                       if not ("bias" in n or "LayerNorm.weight" in n)],
            "weight_decay_mode": 1,
        },
    ]

    optim_config = optim.AdamConfig(params=params, lr=2e-5, do_bias_correction=True)

    self.model = orttrainer.ORTTrainer(self.model, self.model_desc, optim_config, options=options)

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataloader.dataset))
    logger.info("  Num Epochs = %d", num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d", self.args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        self.args.train_batch_size * self.args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if self.args.local_rank != -1 else 1),
    )
parallel, distributed & accumulation) = %d", self.args.train_batch_size * self.args.gradient_accumulation_steps * (torch.distributed.get_world_size() if self.args.local_rank != -1 else 1), ) logger.info(" Gradient Accumulation steps = %d", self.args.gradient_accumulation_steps) logger.info(" Total optimization steps = %d", t_total) global_step = 0 epochs_trained = 0 steps_trained_in_current_epoch = 0 tr_loss = 0.0 logging_loss = 0.0 train_iterator = trange( epochs_trained, int(num_train_epochs), desc="Epoch", disable=self.args.local_rank not in [-1, 0], ) for epoch in train_iterator: epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=self.args.local_rank not in [-1, 0]) for step, inputs in enumerate(epoch_iterator): # Skip past any already trained steps if resuming training if steps_trained_in_current_epoch > 0: steps_trained_in_current_epoch -= 1 continue tr_loss += self._training_step(self.model, inputs) if (step + 1) % self.args.gradient_accumulation_steps == 0 or ( len(epoch_iterator) <= self.args.gradient_accumulation_steps and (step + 1) == len(epoch_iterator)): global_step += 1 if self.args.local_rank in [-1, 0]: if (self.args.logging_steps > 0 and global_step % self.args.logging_steps == 0) or (global_step == 1 and self.args.logging_first_step): logs = {} if self.args.evaluate_during_training: results = self.evaluate() for key, value in results.items(): eval_key = "eval_{}".format(key) logs[eval_key] = value loss_scalar = (tr_loss - logging_loss ) / self.args.logging_steps logs["loss"] = loss_scalar logging_loss = tr_loss epoch_iterator.write( json.dumps({ **logs, **{ "step": global_step } })) if self.args.max_steps > 0 and global_step > self.args.max_steps: epoch_iterator.close() break if self.args.max_steps > 0 and global_step > self.args.max_steps: train_iterator.close() break logger.info("\n\nTraining completed. \n\n") return TrainOutput(global_step, tr_loss / global_step)
def train(self):
    """
    Main training entry point.
    """
    train_dataloader = self.get_train_dataloader()

    if self.args.max_steps > 0:
        t_total = self.args.max_steps
        num_train_epochs = (self.args.max_steps //
                            (len(train_dataloader) // self.args.gradient_accumulation_steps) + 1)
    else:
        t_total = int(len(train_dataloader) // self.args.gradient_accumulation_steps
                      * self.args.num_train_epochs)
        num_train_epochs = self.args.num_train_epochs

    if self.use_new_api:
        lr_scheduler = orttrainer.optim.LinearWarmupLRScheduler(
            t_total, self.args.warmup_steps / float(t_total))

        loss_scaler = amp.DynamicLossScaler() if self.args.fp16 else None

        device = self.args.device.type
        device = f'{device}:{self.args.device.index}' if self.args.device.index else f'{device}:0'

        options = orttrainer.ORTTrainerOptions({
            'batch': {'gradient_accumulation_steps': self.args.gradient_accumulation_steps},
            'device': {'id': device},
            'mixed_precision': {'enabled': self.args.fp16, 'loss_scaler': loss_scaler},
            'debug': {'deterministic_compute': True},
            'utils': {'grad_norm_clip': False},
            'distributed': {
                # we are running single node multi gpu test. thus world_rank = local_rank
                # and world_size = self.args.n_gpu
                'world_rank': max(0, self.args.local_rank),
                'world_size': int(self.world_size),
                'local_rank': max(0, self.args.local_rank),
                'allreduce_post_accumulation': True
            },
            'lr_scheduler': lr_scheduler
        })

        param_optimizer = list(self.model.named_parameters())
        params = [{
            'params': [n for n, p in param_optimizer
                       if "bias" in n or "LayerNorm.weight" in n],
            "weight_decay_mode": 1,
        }, {
            'params': [n for n, p in param_optimizer
                       if not ("bias" in n or "LayerNorm.weight" in n)],
            "weight_decay_mode": 1,
        }]

        optim_config = optim.AdamConfig(params=params, lr=2e-5, do_bias_correction=True)

        self.model = orttrainer.ORTTrainer(self.model, self.new_model_desc, optim_config, options=options)
    else:
        def map_optimizer_attributes(name):
            # note: both branches currently return the same weight decay mode
            no_decay = "bias" in name or "LayerNorm.weight" in name
            if no_decay:
                return {"weight_decay_mode": 1}
            else:
                return {"weight_decay_mode": 1}

        get_lr_this_step = get_linear_schedule_with_warmup(
            self.args.warmup_steps, t_total, self.args.learning_rate)

        loss_scaler = LossScaler('loss_scale_input_name', True,
                                 up_scale_window=2000) if self.args.fp16 else None

        self.model = ORTTrainer(
            self.model, None, self.model_desc, "AdamOptimizer",
            map_optimizer_attributes=map_optimizer_attributes,
            learning_rate_description=IODescription('Learning_Rate', [1, ], torch.float32),
            device=self.args.device,
            gradient_accumulation_steps=self.args.gradient_accumulation_steps,
            world_rank=max(0, self.args.local_rank),
            world_size=int(self.world_size),
            use_mixed_precision=self.args.fp16,
            allreduce_post_accumulation=True,
            get_lr_this_step=get_lr_this_step,
            loss_scaler=loss_scaler,
            enable_grad_norm_clip=False,
            _opset_version=12,
            _use_deterministic_compute=True)

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataloader.dataset))
    logger.info("  Num Epochs = %d", num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d", self.args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        self.args.train_batch_size * self.args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if self.args.local_rank != -1 else 1),
    )
parallel, distributed & accumulation) = %d", self.args.train_batch_size * self.args.gradient_accumulation_steps * (torch.distributed.get_world_size() if self.args.local_rank != -1 else 1), ) logger.info(" Gradient Accumulation steps = %d", self.args.gradient_accumulation_steps) logger.info(" Total optimization steps = %d", t_total) global_step = 0 epochs_trained = 0 steps_trained_in_current_epoch = 0 tr_loss = 0.0 logging_loss = 0.0 train_iterator = trange( epochs_trained, int(num_train_epochs), desc="Epoch", disable=self.args.local_rank not in [-1, 0], ) for epoch in train_iterator: epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=self.args.local_rank not in [-1, 0]) for step, inputs in enumerate(epoch_iterator): # Skip past any already trained steps if resuming training if steps_trained_in_current_epoch > 0: steps_trained_in_current_epoch -= 1 continue tr_loss += self._training_step(self.model, inputs) if (step + 1) % self.args.gradient_accumulation_steps == 0 or ( len(epoch_iterator) <= self.args.gradient_accumulation_steps and (step + 1) == len(epoch_iterator)): global_step += 1 if self.args.local_rank in [-1, 0]: if (self.args.logging_steps > 0 and global_step % self.args.logging_steps == 0) or (global_step == 1 and self.args.logging_first_step): logs = {} if self.args.evaluate_during_training: results = self.evaluate() for key, value in results.items(): eval_key = "eval_{}".format(key) logs[eval_key] = value loss_scalar = (tr_loss - logging_loss ) / self.args.logging_steps if not self.use_new_api: learning_rate_scalar = get_lr_this_step( global_step) logs["learning_rate"] = learning_rate_scalar logs["loss"] = loss_scalar logging_loss = tr_loss epoch_iterator.write( json.dumps({ **logs, **{ "step": global_step } })) if self.args.max_steps > 0 and global_step > self.args.max_steps: epoch_iterator.close() break if self.args.max_steps > 0 and global_step > self.args.max_steps: train_iterator.close() break logger.info("\n\nTraining completed. \n\n") return TrainOutput(global_step, tr_loss / global_step)