def testToyBertCheckpointFrozenWeights():
    # Common setup
    seed = 1
    total_steps = 10
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    opts = orttrainer.ORTTrainerOptions({
        'debug': {
            'deterministic_compute': True
        },
        'utils': {
            'frozen_weights': ['bert.encoder.layer.0.attention.self.value.weight']
        }
    })

    # Create ORTTrainer and save initial state in a dict
    model = load_bert_onnx_model()
    model_desc = bert_model_description()
    optim_config = optim.LambConfig()
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)

    # Train for a few steps
    for i in range(total_steps):
        sample_input = generate_random_input_from_model_desc(model_desc, seed)
        _ = trainer.train_step(*sample_input)
    sample_input = generate_random_input_from_model_desc(model_desc, seed + total_steps + 1)

    # Evaluate once to get a base loss
    loss = trainer.eval_step(*sample_input)

    # Save checkpoint
    state_dict = checkpoint.experimental_state_dict(trainer)

    # Load previous state into another instance of ORTTrainer
    model2 = load_bert_onnx_model()
    model_desc2 = bert_model_description()
    optim_config2 = optim.LambConfig()
    trainer2 = orttrainer.ORTTrainer(model2, model_desc2, optim_config2, options=opts)
    checkpoint.experimental_load_state_dict(trainer2, state_dict)

    # Evaluate once to get the loss after restoring state
    ckpt_loss = trainer2.eval_step(*sample_input)

    # Must match as both trainers have the same dict state
    assert_allclose(loss.cpu(), ckpt_loss.cpu())
    loaded_state_dict = checkpoint.experimental_state_dict(trainer2)
    assert state_dict.keys() == loaded_state_dict.keys()

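# A hedged follow-on check (not in the original test): with 'frozen_weights'
# set, the frozen tensor should be bit-identical before and after a training
# step. A minimal sketch assuming the same helpers used above; state values
# are torch tensors, per testToyBertCheckpointBasic below.
def _example_assert_weight_frozen():
    frozen_name = 'bert.encoder.layer.0.attention.self.value.weight'
    opts = orttrainer.ORTTrainerOptions({
        'debug': {'deterministic_compute': True},
        'utils': {'frozen_weights': [frozen_name]}})
    model = load_bert_onnx_model()
    model_desc = bert_model_description()
    trainer = orttrainer.ORTTrainer(model, model_desc, optim.LambConfig(), options=opts)

    # capture the frozen weight before training, then take one step
    before = checkpoint.experimental_state_dict(trainer)[frozen_name].clone()
    sample_input = generate_random_input_from_model_desc(model_desc, 1)
    trainer.train_step(*sample_input)
    after = checkpoint.experimental_state_dict(trainer)[frozen_name]
    assert torch.all(torch.eq(before, after))
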
def testToyBertCheckpointBasic():
    # Common setup
    seed = 1
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    optim_config = optim.LambConfig()
    opts = orttrainer.ORTTrainerOptions({'debug': {'deterministic_compute': True}})

    # Create ORTTrainer and save initial state in a dict
    model = load_bert_onnx_model()
    model_desc = bert_model_description()
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)
    sd = checkpoint.experimental_state_dict(trainer)

    ## All initializers must be present in the state_dict
    ## when the specified model for ORTTrainer is an ONNX model
    for param in trainer._onnx_model.graph.initializer:
        assert param.name in sd

    ## Modify one of the state values and load into ORTTrainer
    sd['bert.encoder.layer.0.attention.output.LayerNorm.weight'] += 10
    checkpoint.experimental_load_state_dict(trainer, sd)

    ## Save a checkpoint
    ckpt_dir = 'testdata'
    checkpoint.experimental_save_checkpoint(trainer, ckpt_dir, 'bert_toy_save_test')
    del trainer
    del model

    # Create a new ORTTrainer and load the checkpoint from the previous ORTTrainer
    model2 = load_bert_onnx_model()
    model_desc2 = bert_model_description()
    trainer2 = orttrainer.ORTTrainer(model2, model_desc2, optim_config, options=opts)
    checkpoint.experimental_load_checkpoint(trainer2, ckpt_dir, 'bert_toy_save_test')
    loaded_sd = checkpoint.experimental_state_dict(trainer2)

    # Assert that the original state and the one loaded from the checkpoint match
    for k, v in loaded_sd.items():
        assert torch.all(torch.eq(v, sd[k]))

def create_orttrainer_and_load_checkpoint(device, trainer_opts, checkpoint_dir):
    """Instantiate and load checkpoint into trainer

    - Instantiates the ORTTrainer with the given trainer_opts configuration for a simple transformer model
    - Loads the checkpoint from directory checkpoint_dir into the trainer
    - Runs eval_step on the trainer so the trainer onnx graph is initialized
    - Returns the trainer state_dict and the pytorch model
    """
    seed = 1
    torch.manual_seed(seed)
    set_seed(seed)

    # PyTorch transformer model setup
    learning_rate = 0.1
    optim_config = optim.LambConfig(lr=learning_rate)
    model, model_desc, loss_fn, batcher_fn, train_data, _, _ = _load_pytorch_transformer_model(device)
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=loss_fn,
                                    options=orttrainer.ORTTrainerOptions(trainer_opts))

    # load the checkpoint into the trainer
    checkpoint.experimental_load_checkpoint(trainer, checkpoint_dir)

    # run an eval step to initialize the graph
    torch.manual_seed(seed)
    set_seed(seed)
    data, targets = batcher_fn(train_data, 0)
    trainer.eval_step(data, targets)

    return checkpoint.experimental_state_dict(trainer), model

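# A hedged usage sketch (not part of the original suite): load a checkpoint via
# the helper above and compare the restored trainer state against a reference
# state dict pickled next to the checkpoint (see _save below for the layout).
# The comparison reuses the torch.eq pattern from testToyBertCheckpointBasic.
def _example_verify_loaded_checkpoint(device, trainer_opts, checkpoint_dir, state_dict_key_name):
    state_dict, _ = create_orttrainer_and_load_checkpoint(device, trainer_opts, checkpoint_dir)

    # load the pickled reference state saved alongside the checkpoint
    with open(os.path.join(checkpoint_dir, state_dict_key_name + '.pkl'), "rb") as f:
        expected_state_dict = pickle.load(f)[state_dict_key_name]

    # the checkpoint round trip must preserve both keys and values
    assert expected_state_dict.keys() == state_dict.keys()
    for k, v in expected_state_dict.items():
        assert torch.all(torch.eq(v, state_dict[k]))
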
def load_model_optim_state_and_eval(device, trainer_opts, use_lamb=True):
    learning_rate = 0.1
    seed = 1
    torch.manual_seed(seed)
    set_seed(seed)
    optim_config = optim.LambConfig(lr=learning_rate) if use_lamb else optim.AdamConfig(lr=learning_rate)
    model, model_desc, loss_fn, batcher_fn, train_data, _, _ = _load_pytorch_transformer_model(device)
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=loss_fn,
                                    options=orttrainer.ORTTrainerOptions(trainer_opts))

    # load dummy optimizer state into the trainer
    dummy_init_state = generate_dummy_optim_state(model, optim_config)
    checkpoint._experimental_load_optimizer_state(trainer, dummy_init_state)

    # run an eval step to initialize the graph
    data, targets = batcher_fn(train_data, 0)
    trainer.eval_step(data, targets)

    return dummy_init_state, checkpoint.experimental_state_dict(trainer)

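# An illustrative check (assumed, not from the original file) for the pair the
# helper above returns: the dummy optimizer state loaded into the trainer
# should survive graph initialization unchanged. It reuses the extraction and
# assertion helpers from the BERT optimizer-state test below; the import paths
# for _test_commons/_test_helpers may differ in this file.
def _example_check_optim_state_round_trip(device, trainer_opts, use_lamb=True):
    expected_optim_state, trainer_state = load_model_optim_state_and_eval(device, trainer_opts, use_lamb)
    optim_config = optim.LambConfig(lr=0.1) if use_lamb else optim.AdamConfig(lr=0.1)
    actual_optim_state = _test_commons.get_optim_state_from_state_dict(trainer_state, optim_config)
    _test_helpers.assert_optim_state(expected_optim_state, actual_optim_state)
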
def _save(trainer, checkpoint_dir, state_dict_key_name):
    """Saves the ORTTrainer checkpoint and the complete state dictionary to the given checkpoint_dir directory"""

    # save current model parameters as a checkpoint
    makedir(checkpoint_dir)
    checkpoint.experimental_save_checkpoint(trainer, checkpoint_dir)
    state_dict = checkpoint.experimental_state_dict(trainer)
    with open(os.path.join(checkpoint_dir, state_dict_key_name + '.pkl'), "wb") as f:
        pickle.dump({state_dict_key_name: state_dict}, f)

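# Companion sketch to _save (assumed, not in the original file): the matching
# load path, relying on the same pickle layout {state_dict_key_name: state_dict}
# and '<key>.pkl' file naming used above.
def _load_saved_state_dict(checkpoint_dir, state_dict_key_name):
    """Loads the pickled reference state dictionary written by _save"""
    pkl_path = os.path.join(checkpoint_dir, state_dict_key_name + '.pkl')
    with open(pkl_path, "rb") as f:
        return pickle.load(f)[state_dict_key_name]
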
def update_torch_model(self):
    if self.ort_model:
        logger.info("Updating weights of torch model from ORT model.")
        ort_state_dict = checkpoint.experimental_state_dict(self.ort_model)
        self.model.load_state_dict(ort_state_dict, strict=False)
    else:
        logger.warning("No ORT model found to update weights from, assuming torch model is up to date.")

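# Usage note (assumed context): update_torch_model is a method on a wrapper
# object exposing both .ort_model (an ORTTrainer) and .model (the torch
# module). A typical call site, sketched with a hypothetical 'runner' object:
#
#   runner.update_torch_model()  # pull the latest ORT weights into the torch model
#   torch.save(runner.model.state_dict(), 'torch_model_from_ort.pt')
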
def testToyBertLoadOptimState(optimizer, mixedprecision_enabled):
    # Common setup
    rtol = 1e-03
    device = 'cuda'
    seed = 1
    torch.manual_seed(seed)
    onnxruntime.set_seed(seed)
    optim_config = optimizer
    opts = orttrainer.ORTTrainerOptions({
        'debug': {
            'deterministic_compute': True
        },
        'device': {
            'id': device
        },
        'mixed_precision': {
            'enabled': mixedprecision_enabled
        },
        'distributed': {
            'allreduce_post_accumulation': True
        }
    })

    # Create ORTTrainer and load a dummy optimizer state into it
    model = load_bert_onnx_model()
    model_desc = bert_model_description()
    dummy_init_state = _test_commons.generate_dummy_optim_state(model, optimizer)
    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)
    checkpoint._experimental_load_optimizer_state(trainer, dummy_init_state)

    # Expected values
    expected_eval_loss = [10.997552871]
    input_ids = torch.tensor([[26598], [21379], [19922], [5219], [5644], [20559], [23777], [25672],
                              [22969], [16824], [16822], [635], [27399], [20647], [18519], [15546]],
                             device=device)
    segment_ids = torch.tensor([[0], [1], [0], [1], [0], [0], [1], [0],
                                [0], [1], [1], [0], [0], [1], [1], [1]], device=device)
    input_mask = torch.tensor([[0], [0], [0], [0], [1], [1], [1], [0],
                               [1], [1], [0], [0], [0], [1], [0], [0]], device=device)
    masked_lm_labels = torch.tensor([[25496], [16184], [11005], [16228], [14884], [21660], [8678], [23083],
                                     [4027], [8397], [11921], [1333], [26482], [1666], [17925], [27978]],
                                    device=device)
    next_sentence_labels = torch.tensor([0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0], device=device)

    # Actual values
    _ = trainer.eval_step(input_ids, segment_ids, input_mask, masked_lm_labels, next_sentence_labels)

    actual_state = checkpoint.experimental_state_dict(trainer)
    actual_optim_state = _test_commons.get_optim_state_from_state_dict(actual_state, optimizer)
    _test_helpers.assert_optim_state(dummy_init_state, actual_optim_state)

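# Invocation note (illustrative; the decorator below is an assumption, not
# copied from this file): in a pytest suite this test would typically be driven
# by parametrization over optimizer configs and mixed-precision settings, e.g.:
#
# @pytest.mark.parametrize("optimizer, mixedprecision_enabled", [
#     (optim.LambConfig(), False),
#     (optim.AdamConfig(), False),
#     (optim.LambConfig(), True),
#     (optim.AdamConfig(), True),
# ])
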
def main():
    args = parse_arguments()

    if args.use_env and 'LOCAL_RANK' in os.environ:
        args.local_rank = int(os.environ['LOCAL_RANK'])

    random.seed(args.seed + args.local_rank)
    np.random.seed(args.seed + args.local_rank)
    torch.manual_seed(args.seed + args.local_rank)
    torch.cuda.manual_seed(args.seed + args.local_rank)
    worker_init = WorkerInitObj(args.seed + args.local_rank)

    device, args = setup_training(args)
    dllogger.log(step="PARAMETER", data={"Config": [str(args)]})

    # Prepare the model (restores checkpoint state when resuming)
    model, checkpoint, global_step = prepare_model(args, device)

    if is_main_process(args):
        dllogger.log(step="PARAMETER", data={"SEED": args.seed})

    raw_train_start = time.time()
    if args.do_train:
        if is_main_process(args):
            dllogger.log(step="PARAMETER", data={"train_start": True})
            dllogger.log(step="PARAMETER", data={"batch_size_per_gpu": args.train_batch_size})
            dllogger.log(step="PARAMETER", data={"learning_rate": args.learning_rate})

        most_recent_ckpts_paths = []
        average_loss = 0.0  # averaged loss every args.log_freq steps
        epoch = 0
        training_steps = 0

        pool = ProcessPoolExecutor(1)

        # Note: We loop infinitely over epochs, termination is handled via iteration count
        while True:
            thread = None
            if not args.resume_from_checkpoint or epoch > 0 or (args.phase2 and global_step < 1) or args.init_checkpoint:
                files = [os.path.join(args.input_dir, f) for f in os.listdir(args.input_dir)
                         if os.path.isfile(os.path.join(args.input_dir, f)) and 'training' in f]
                files.sort()
                num_files = len(files)
                random.shuffle(files)
                f_start_id = 0
            else:
                f_start_id = checkpoint['files'][0]
                files = checkpoint['files'][1:]
                args.resume_from_checkpoint = False
                num_files = len(files)

            shared_file_list = {}

            if torch.distributed.is_initialized():
                world_size = torch.distributed.get_world_size()
                world_rank = torch.distributed.get_rank()
            elif hasattr(args, 'world_size'):
                world_size = args.world_size
                world_rank = args.world_rank
            else:
                world_size = 1
                world_rank = 0

            if world_size > num_files:
                remainder = world_size % num_files
                data_file = files[(f_start_id * world_size + world_rank + remainder * f_start_id) % num_files]
            elif world_size > 1:
                data_file = files[(f_start_id * world_size + world_rank) % num_files]
            else:
                data_file = files[f_start_id % num_files]

            previous_file = data_file

            train_data = pretraining_dataset(data_file, args.max_predictions_per_seq)
            train_sampler = RandomSampler(train_data)
            # we need to skip the last batch when we hard-code inputs as an optimization
            train_dataloader = DataLoader(train_data,
                                          sampler=train_sampler,
                                          batch_size=args.train_batch_size * args.n_gpu,
                                          num_workers=4,
                                          worker_init_fn=worker_init,
                                          pin_memory=True,
                                          drop_last=True)

            gpu_batch_size = args.train_batch_size // args.gradient_accumulation_steps

            if len(files) == 1:
                f_start_id = -1
            for f_id in range(f_start_id + 1, len(files)):
                if world_size > num_files:
                    data_file = files[(f_id * world_size + world_rank + remainder * f_id) % num_files]
                elif world_size > 1:
                    data_file = files[(f_id * world_size + world_rank) % num_files]
                else:
                    data_file = files[f_id % num_files]

                previous_file = data_file

                dataset_future = pool.submit(create_pretraining_dataset, data_file,
                                             args.max_predictions_per_seq, shared_file_list,
                                             args, worker_init)

                train_iter = tqdm(train_dataloader, desc="Iteration",
                                  disable=args.disable_progress_bar) if is_main_process(args) else train_dataloader

                prev_step_time = time.time()
                for step, batch in enumerate(train_iter):
                    training_steps += 1
                    batch = [t.to(device) for t in batch]
                    input_ids, segment_ids, input_mask, masked_lm_labels, next_sentence_labels = batch
                    divisor = args.gradient_accumulation_steps
                    loss, global_step = ort_supplement.run_ort_training_step(args, global_step, training_steps, model, batch)
                    average_loss += loss.item()

                    if global_step >= args.max_steps:
                        train_time_raw = time.time() - raw_train_start
                        last_num_steps = int(training_steps / args.gradient_accumulation_steps) % args.log_freq
                        last_num_steps = args.log_freq if last_num_steps == 0 else last_num_steps
                        average_loss = torch.tensor(average_loss, dtype=torch.float32).cuda()
                        average_loss = average_loss / (last_num_steps * divisor)
                        if torch.distributed.is_initialized():
                            average_loss /= torch.distributed.get_world_size()
                            torch.distributed.all_reduce(average_loss)
                        final_loss = average_loss.item()
                        if is_main_process(args):
                            dllogger.log(step=(epoch, global_step,), data={"final_loss": final_loss})
                    elif training_steps % (args.log_freq * args.gradient_accumulation_steps) == 0:
                        throughput = (args.train_batch_size * args.gradient_accumulation_steps) / (time.time() - prev_step_time)
                        print('throughput = ', throughput, 'seq/sec')
                        prev_step_time = time.time()
                        sys.stdout.flush()

                        if is_main_process(args):
                            data = {"average_loss": average_loss / (args.log_freq * divisor),
                                    "step_loss": loss.item() * args.gradient_accumulation_steps / divisor}
                            dllogger.log(step=(epoch, global_step,), data=data)
                        average_loss = 0

                    if global_step >= args.max_steps or training_steps % (
                            args.num_steps_per_checkpoint * args.gradient_accumulation_steps) == 0:
                        if is_main_process(args) and not args.skip_checkpoint:
                            # Save a trained model
                            dllogger.log(step="PARAMETER", data={"checkpoint_step": global_step})
                            # Only save the model itself
                            model_to_save = model.module if hasattr(model, 'module') else model
                            if args.resume_step < 0 or not args.phase2:
                                output_save_file = os.path.join(args.output_dir, "ckpt_{}.pt".format(global_step))
                            else:
                                output_save_file = os.path.join(args.output_dir, "ckpt_{}.pt".format(global_step + args.phase1_end_step))
                            if args.do_train:
                                state = {'model': model_to_save.state_dict()
                                                  if hasattr(model_to_save, 'state_dict')
                                                  else experimental_state_dict(model_to_save),
                                         'files': [f_id] + files}
                                torch.save(state, output_save_file)

                                most_recent_ckpts_paths.append(output_save_file)
                                if len(most_recent_ckpts_paths) > 3:
                                    ckpt_to_be_removed = most_recent_ckpts_paths.pop(0)
                                    os.remove(ckpt_to_be_removed)

                        if global_step >= args.max_steps:
                            if is_main_process(args):
                                print('-----------------------save onnx model-----------------------')
                                if not args.phase2:
                                    model_to_save.save_as_onnx('{}/phase1_bert.onnx'.format(args.output_dir))
                                else:
                                    model_to_save.save_as_onnx('{}/final_bert.onnx'.format(args.output_dir))
                            del train_dataloader
                            # thread.join()
                            return args, final_loss, train_time_raw

                del train_dataloader
                # thread.join()

                # Make sure pool has finished and switch train_dataloader
                # NOTE: Will block until complete
                train_dataloader, data_file = dataset_future.result(timeout=None)

            epoch += 1

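# Assumed script entry point (a sketch; the original file's epilogue is not
# shown in this snippet, so this minimal guard is illustrative only):
if __name__ == "__main__":
    main()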