def testWrapModelLossFnStateDict(self):
    """state_dict() of a trainer built from (model, external loss_fn) must
    expose exactly the wrapped model's parameters — not any loss-wrapper state."""
    torch.manual_seed(1)  # deterministic weights and inputs
    device = torch.device("cuda")

    class LinearModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(2, 4)

        def forward(self, y=None, x=None):
            # y is an optional additive term; without it a constant is added.
            if y is not None:
                return self.linear(x) + y
            else:
                return self.linear(x) + torch.ones(2, 4)

    pt_model = LinearModel()
    data = torch.randn(2, 2)
    label = torch.tensor([0, 1], dtype=torch.int64)
    input_desc = IODescription('x', [2, 2], torch.float32)
    label_desc = IODescription('label', [2, ], torch.int64, num_classes=4)
    output_desc = IODescription('output', [2, 4], torch.float32)
    loss_desc = IODescription('loss', [], torch.float32)
    model_desc = ModelDescription([input_desc, label_desc],
                                  [loss_desc, output_desc])

    def loss_fn(x, label):
        return F.nll_loss(F.log_softmax(x, dim=1), label)

    def get_lr_this_step(global_step):
        learningRate = 0.02
        return torch.tensor([learningRate])

    ort_trainer = ORTTrainer(
        pt_model, loss_fn, model_desc, "SGDOptimizer", None,
        IODescription('Learning_Rate', [1, ], torch.float32), device,
        get_lr_this_step=get_lr_this_step)
    # One step builds the ORT session and initializes trainer state.
    ort_trainer.train_step(x=data, label=label)
    state_dict = ort_trainer.state_dict()
    # Only LinearModel's own parameters should be present.
    assert state_dict.keys() == {'linear.bias', 'linear.weight'}
def bart_model_description(args):
    """Return the ORT ModelDescription for BART training:
    (src_tokens, prev_output_tokens, target) -> loss.

    Symbolic ('batch', ...) dimensions allow variable input sizes; concrete
    sizes could be substituted here instead to permit more graph optimization.
    NOTE: a 'src_lengths' input is deliberately NOT part of the description
    (the original code built one but never passed it to ModelDescription).
    Removed: unused locals (batch, max_tokens, max_tokens_valid,
    src_lengths_desc) and dead commented-out code — return value unchanged.
    """
    vocab_size = 50349
    # allow variable input sizes:
    src_tokens_desc = IODescription('src_tokens', ['batch', 'max_src_tokens'],
                                    torch.int64, num_classes=vocab_size)
    prev_output_tokens_desc = IODescription('prev_output_tokens',
                                            ['batch', 'max_out_tokens'],
                                            torch.int64, num_classes=vocab_size)
    target_desc = IODescription('target', ['max_tgt_tokens'], torch.int64,
                                num_classes=vocab_size)
    loss_desc = IODescription('loss', [], torch.float32)
    return ModelDescription(
        [src_tokens_desc, prev_output_tokens_desc, target_desc], [loss_desc])
def test_layer_norm(self):
    """nn.LayerNorm should export as exactly one LayerNormalization node."""

    class LayerNormNet(nn.Module):
        def __init__(self, target):
            super(LayerNormNet, self).__init__()
            self.ln_1 = nn.LayerNorm(10)
            self.loss = nn.CrossEntropyLoss()
            self.target = target

        def forward(self, x):
            normed = self.ln_1(x)
            return self.loss(normed, self.target), normed

    cpu = torch.device("cpu")
    labels = torch.ones(20, 10, 10, dtype=torch.int64).to(cpu)
    net = LayerNormNet(labels)
    batch = torch.randn(20, 5, 10, 10, dtype=torch.float32).to(cpu)

    # Scalar loss plus the normalized activations as outputs.
    model_desc = ModelDescription(
        [IODescription('input', [], "float32")],
        [IODescription('output0', [], "float32"),
         IODescription('output1', [20, 5, 10, 10], "float32")])

    lr = torch.tensor([1.0000000e+00]).to(cpu)
    exported = self.get_onnx_model(net, model_desc, [batch, lr], cpu)

    assert self.count_nodes(exported, "LayerNormalization") == 1
    assert self.count_all_nodes(exported) == 3
def model_description():
    """IO description: inputs (src, label) -> outputs (loss, output).

    Relies on the module-level globals bptt, batch_size and ntokens.
    """
    src_desc = IODescription('src', [bptt, batch_size], torch.float32)
    tgt_desc = IODescription('label', [bptt, batch_size, ntokens], torch.int64)
    scalar_loss = IODescription('loss', [], torch.float32)
    logits_desc = IODescription('output', [bptt, batch_size, ntokens], torch.float32)
    return ModelDescription([src_desc, tgt_desc], [scalar_loss, logits_desc])
def transformer_model_description():
    """Return (ModelDescription, learning-rate IODescription) for the
    transformer LM: inputs (input1, label) -> scalar loss.

    Fix: the learning rate is fed to ORTTrainer as a 1-element float tensor,
    so its description shape must be [1]. The original used the global `lr`
    (the learning-rate *value*) as the dimension, inconsistent with every
    other Learning_Rate description in this codebase.
    """
    input_desc = IODescription('input1', [bptt, batch_size], torch.float32)
    label_desc = IODescription('label', [bptt, batch_size, ntokens], torch.int64)
    loss_desc = IODescription('loss', [], torch.float32)
    model_desc = ModelDescription([input_desc, label_desc], [loss_desc])
    # Shape [1]: a single-element learning-rate tensor.
    lr_desc = IODescription('Learning_Rate', [1, ], torch.float32)
    return model_desc, lr_desc
def mnist_model_description():
    """Graph IO for MNIST: (input1, label) -> (loss, probability)."""
    inputs = [
        IODescription('input1', ['batch', 784], torch.float32),
        IODescription('label', ['batch', ], torch.int64, num_classes=10),
    ]
    outputs = [
        IODescription('loss', [], torch.float32),
        IODescription('probability', ['batch', 10], torch.float32),
    ]
    return ModelDescription(inputs, outputs)
def mnist_model_description():
    """Describe MNIST graph IO: (input1, label) in, (loss, probability) out."""
    input_desc = IODescription("input1", ["batch", 784], torch.float32)
    label_desc = IODescription("label", ["batch", ], torch.int64, num_classes=10)
    loss_desc = IODescription("loss", [], torch.float32)
    prob_desc = IODescription("probability", ["batch", 10], torch.float32)
    return ModelDescription([input_desc, label_desc], [loss_desc, prob_desc])
def testTrainingAndEvalDropout(self):
    # Temporarily disable this test.
    # The graph below will trigger ORT
    # to sort backward graph before forward graph which gives incorrect result.
    # TODO Re-enable when that is fixed.
    return

    # --- everything below is currently unreachable; kept for re-enabling ---
    class TwoDropoutNet(nn.Module):
        # Two stacked Dropout layers over a learnable bias vector.
        def __init__(self, drop_prb_1, drop_prb_2, dim_size):
            super(TwoDropoutNet, self).__init__()
            self.drop_1 = nn.Dropout(drop_prb_1)
            self.drop_2 = nn.Dropout(drop_prb_2)
            self.weight_1 = torch.nn.Parameter(torch.zeros(dim_size, dtype=torch.float32))

        def forward(self, x):
            x = x + self.weight_1
            x = self.drop_1(x)
            x = self.drop_2(x)
            output = x
            return output[0]

    dim_size = 3
    device = torch.device("cuda", 0)
    # This will drop all values, therefore expecting all 0 in output tensor
    model = TwoDropoutNet(0.999, 0.999, dim_size)
    input_desc = IODescription('input', [dim_size], torch.float32)
    output_desc = IODescription('output', [], torch.float32)
    model_desc = ModelDescription([input_desc], [output_desc])
    lr_desc = ort_trainer_learning_rate_description()
    # process_dropout rewrites Dropout nodes so eval uses ratio 0.
    model = ORTTrainer(model, None, model_desc, "LambOptimizer",
                       map_optimizer_attributes, lr_desc, device,
                       postprocess_model=process_dropout,
                       world_rank=0, world_size=1)
    input = torch.ones(dim_size, dtype=torch.float32).to(device)
    expected_training_output = [0.0]
    expected_eval_output = [1.0]
    learning_rate = torch.tensor([1.0000000e+00]).to(device)
    input_args = [input, learning_rate]
    train_output = model.train_step(*input_args)

    rtol = 1e-04
    assert_allclose(expected_training_output, train_output.item(), rtol=rtol,
                    err_msg="dropout training loss mismatch")

    # In eval mode dropout is a no-op, so the input passes through.
    eval_output = model.eval_step(input)
    assert_allclose(expected_eval_output, eval_output.item(), rtol=rtol,
                    err_msg="dropout eval loss mismatch")

    # Do another train step to make sure it's using original ratios
    train_output_2 = model.train_step(*input_args)
    assert_allclose(expected_training_output, train_output_2.item(), rtol=rtol,
                    err_msg="dropout training loss 2 mismatch")
def get_trainer(
    self,
    model,
    model_desc,
    device,
    onnx_opset_ver=12,
    frozen_weights=None,
    internal_loss_fn=False,
    get_lr_this_step=None,
    optimizer="SGDOptimizer",
):
    """Build an ORTTrainer around *model* for the MNIST tests.

    frozen_weights: optional list of parameter names excluded from training;
        defaults to an empty list. Fix: the original used a mutable default
        argument (`frozen_weights=[]`), shared across calls — replaced with
        the `None` sentinel idiom.
    internal_loss_fn: when True the model computes its own loss, so no
        external loss function is wired in.
    """
    if frozen_weights is None:
        frozen_weights = []
    loss_fn = MNISTWrapper.my_loss if not internal_loss_fn else None
    return ORTTrainer(
        model,
        loss_fn,
        model_desc,
        optimizer,
        None,
        IODescription("Learning_Rate", [1, ], torch.float32),
        device,
        _opset_version=onnx_opset_ver,
        frozen_weights=frozen_weights,
        get_lr_this_step=get_lr_this_step,
    )
def create_ort_trainer(args, device, model):
    """Wrap *model* in an ORTTrainer configured with a LAMB optimizer;
    attaches a LossScaler to args when fp16 is enabled."""
    # Cap ORT's CUDA memory arena at 1 GB.
    from onnxruntime.capi._pybind_state import set_cuda_mem_limit
    gpu_mem_limit_gb = 1
    set_cuda_mem_limit(int(gpu_mem_limit_gb * 1024 ** 3))

    # BertLAMB default initial settings: b1=0.9, b2=0.999, e=1e-6
    def map_optimizer_attributes(name):
        # Bias/normalization parameters get no weight decay.
        decay_off = any(key in name
                        for key in ("bias", "gamma", "beta", "LayerNorm"))
        weight_decay = 0.0 if decay_off else 0.01
        return {"alpha": 0.9, "beta": 0.999,
                "lambda": weight_decay, "epsilon": 1e-6}

    # ORTTrainer creates a LambOptimizer with the attributes above;
    # train_step then does forward, backward, and optimizer step.
    trainer = ORTTrainer(model, None, bert_model_description(args),
                         "LambOptimizer", map_optimizer_attributes,
                         IODescription('Learning_Rate', [1, ], torch.float32),
                         device, _opset_version=10)
    if args.fp16:
        args.ort_loss_scale = LossScaler(trainer.loss_scale_input_name, True,
                                         up_scale_window=2000)
    return trainer
def get_onnx_model(self, model, model_desc, inputs, device,
                   _enable_internal_postprocess=True, _extra_postprocess=None):
    """Run one train step through a throwaway ORTTrainer and return the
    ONNX graph it captured."""
    lr_desc = IODescription('Learning_Rate', [1, ], torch.float32)
    trainer = ORTTrainer(
        model, None, model_desc, "LambOptimizer", map_optimizer_attributes,
        lr_desc, device,
        world_rank=0, world_size=1,
        _opset_version=12,
        _enable_internal_postprocess=_enable_internal_postprocess,
        _extra_postprocess=_extra_postprocess)
    # Building the training session materializes onnx_model_.
    trainer.train_step(*inputs)
    return trainer.onnx_model_
def ort_trainer_learning_rate_description():
    """The learning rate is fed to ORTTrainer as a single-element float tensor."""
    return IODescription("Learning_Rate", [1, ], torch.float32)
def create_ort_trainer(args, device, model):
    """Build an ORTTrainer for BERT pre-training with a LAMB optimizer.

    Configures per-parameter optimizer attributes (no weight decay for
    bias/gamma/beta/LayerNorm parameters), gradient accumulation,
    distributed rank/size, mixed precision, and optional DeepSpeed ZeRO
    stage 1. When fp16 is on, a LossScaler is attached to *args*.

    Fixes: redundant `True if x else False` ternaries replaced with
    `bool(x)`; `setattr(args, 'ort_loss_scale', ...)` replaced with a
    direct attribute assignment (identical behavior, idiomatic form).
    """
    # set GPU memory limitation (per card!)
    from onnxruntime.capi._pybind_state import set_cuda_mem_limit
    ort_cuda_mem_limit_in_gbs = args.gpu_memory_limit_gb
    set_cuda_mem_limit(int(ort_cuda_mem_limit_in_gbs * 1024 * 1024 * 1024))

    # BertLAMB default initial settings: b1=0.9, b2=0.999, e=1e-6
    def map_optimizer_attributes(name):
        # Parameters whose names contain any of these keys get no weight decay.
        no_decay_keys = ["bias", "gamma", "beta", "LayerNorm"]
        no_decay = any(key in name for key in no_decay_keys)
        decay = 0.0 if no_decay else 0.01
        return {"alpha": 0.9, "beta": 0.999, "lambda": decay, "epsilon": 1e-6}

    # we request ORTTrainer to create a LambOptimizer with given optimizer_attributes.
    # train_step does forward, backward, and optimize step.
    model = ORTTrainer(
        model,
        None,
        bert_model_description(args),
        "LambOptimizer",
        map_optimizer_attributes,
        IODescription('Learning_Rate', [1, ], torch.float32),
        device,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        world_rank=args.world_rank,
        world_size=args.world_size,
        use_mixed_precision=bool(args.fp16),
        allreduce_post_accumulation=bool(args.allreduce_post_accumulation),
        deepspeed_zero_stage=1 if args.deepspeed_zero_stage else 0,
        _opset_version=12)

    if args.fp16:
        # Dynamic loss scaling for mixed-precision training.
        args.ort_loss_scale = LossScaler(model.loss_scale_input_name, True,
                                         up_scale_window=2000)
    return model
def create_and_check_bert_model(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
    """Wrap a plain BertModel (no loss output) in ORTTrainer and check the
    shapes of sequence_output / pooled_output from inference-style calls."""
    model = BertModel(config=config)
    model.to(input_ids.device)
    model.eval()
    # Plain-PyTorch forward first.
    sequence_output, pooled_output = model(
        input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)

    # failed because there is not loss output — this description exposes
    # hidden states / pooled output only, with no scalar loss.
    model_desc = ModelDescription([
        self.input_ids_desc, self.attention_mask_desc, self.token_type_ids_desc
    ], [self.last_hidden_state_desc, self.pooler_output_desc])

    # Hard-coded stand-ins for command-line args.
    args_gradient_accumulation_steps = 8
    args_local_rank = 0
    args_world_size = 1
    args_fp16 = True
    args_allreduce_post_accumulation = True

    # Rebind `model` to the ORT-wrapped trainer; later calls go through ORT.
    model = ORTTrainer(
        model, None, model_desc, "LambOptimizer",
        map_optimizer_attributes=map_optimizer_attributes,
        learning_rate_description=IODescription('Learning_Rate', [1, ],
                                                torch.float32),
        device=self.device,
        postprocess_model=postprocess_model,
        gradient_accumulation_steps=args_gradient_accumulation_steps,
        world_rank=args_local_rank, world_size=args_world_size,
        use_mixed_precision=True if args_fp16 else False,
        allreduce_post_accumulation=True if args_allreduce_post_accumulation else False)

    sequence_output, pooled_output = model(input_ids,
                                           token_type_ids=token_type_ids)
    sequence_output, pooled_output = model(input_ids)
    result = {
        "sequence_output": sequence_output,
        "pooled_output": pooled_output,
    }
    self.parent.assertListEqual(
        list(result["sequence_output"].size()),
        [self.batch_size, self.seq_length, self.hidden_size])
    self.parent.assertListEqual(list(result["pooled_output"].size()),
                                [self.batch_size, self.hidden_size])
def gpt2_model_description(self, n_head, vocab_size, n_hidden, n_layer, n_ctx, batch_size):
    """Return the ORT ModelDescription for GPT-2 LM training:
    (input_ids, labels) -> scalar loss.

    Fix: eager ``"...".format(x)`` logging replaced with lazy %-style
    arguments so the message is only formatted when INFO is enabled;
    rendered output is identical.
    """
    logger.info("****num of head is: %s", n_head)
    logger.info("****vocab size is: %s", vocab_size)
    logger.info("****num of hidden layer is: %s", n_hidden)
    logger.info("****num of layer is: %s", n_layer)
    logger.info("****seq length is: %s", n_ctx)

    input_ids_desc = IODescription('input_ids', [batch_size, n_ctx],
                                   torch.int64, num_classes=vocab_size)
    labels_desc = IODescription('labels', [batch_size, n_ctx],
                                torch.int64, num_classes=vocab_size)
    loss_desc = IODescription('loss', [], torch.float32)
    return ModelDescription([input_ids_desc, labels_desc], [loss_desc])
def get_trainer(self, model, model_desc, device, onnx_opset_ver=12):
    """Build an ORTTrainer around *model* with SGD and MNISTWrapper.my_loss."""
    lr_desc = IODescription('Learning_Rate', [1, ], torch.float32)
    return ORTTrainer(model, MNISTWrapper.my_loss, model_desc,
                      "SGDOptimizer", None, lr_desc, device,
                      _opset_version=onnx_opset_ver)
def test_expand(self):
    """The exported graph must contain exactly one Expand node, and the
    inferred type/shape of its output must match the second graph input."""

    class ExpandNet(nn.Module):
        def __init__(self, target):
            super(ExpandNet, self).__init__()
            self.loss = nn.CrossEntropyLoss()
            self.target = target
            self.linear = torch.nn.Linear(2, 2)

        def forward(self, x, x1):
            # Broadcast x up to x1's shape before the linear layer.
            output = x.expand_as(x1)
            output = self.linear(output)
            output = output + output
            loss = self.loss(output, self.target)
            return loss, output

    device = torch.device("cpu")
    target = torch.ones(5, 5, 2, dtype=torch.int64).to(device)
    model = ExpandNet(target).to(device)
    x = torch.randn(5, 3, 1, 2, dtype=torch.float32).to(device)
    x1 = torch.randn(5, 3, 5, 2, dtype=torch.float32).to(device)
    input0_desc = IODescription('x', [5, 3, 1, 2], "float32")
    input1_desc = IODescription('x1', [5, 3, 5, 2], "float32")
    output0_desc = IODescription('output0', [], "float32")
    output1_desc = IODescription('output1', [5, 3, 5, 2], "float32")
    model_desc = ModelDescription([input0_desc, input1_desc],
                                  [output0_desc, output1_desc])
    learning_rate = torch.tensor([1.0000000e+00]).to(device)
    input_args = [x, x1, learning_rate]
    onnx_model = self.get_onnx_model(model, model_desc, input_args, device)

    # check that expand output has shape
    expand_nodes = self.find_nodes(onnx_model, "Expand")
    assert len(expand_nodes) == 1

    model_info = onnx_model.graph.value_info
    # The first value_info entry is the Expand output; its type must equal
    # that of graph input 1 (x1), i.e. float32 with shape [5, 3, 5, 2].
    assert model_info[0].name == expand_nodes[0].output[0]
    assert model_info[0].type == onnx_model.graph.input[1].type
def bert_model_description():
    """IO description for BERT pre-training with symbolic batch/sequence
    dimensions; the single output is the scalar loss."""
    vocab_size = 30528
    return ModelDescription(
        [
            IODescription('input_ids', ['batch', 'max_seq_len_in_batch'],
                          torch.int64, num_classes=vocab_size),
            IODescription('segment_ids', ['batch', 'max_seq_len_in_batch'],
                          torch.int64, num_classes=2),
            IODescription('input_mask', ['batch', 'max_seq_len_in_batch'],
                          torch.int64, num_classes=2),
            IODescription('masked_lm_labels', ['batch', 'max_seq_len_in_batch'],
                          torch.int64, num_classes=vocab_size),
            IODescription('next_sentence_labels', ['batch', ],
                          torch.int64, num_classes=2),
        ],
        [IODescription('loss', [], torch.float32)])
def create_and_check_bert_for_masked_lm(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
    """Run BertForMaskedLM once in plain PyTorch, then wrap it in ORTTrainer
    (mixed precision + gradient accumulation) and run the same forward."""
    model = BertForMaskedLM(config=config)
    model.eval()
    loss, prediction_scores = model(input_ids, attention_mask=input_mask,
                                    token_type_ids=token_type_ids,
                                    masked_lm_labels=token_labels)
    #####
    model_desc = ModelDescription([
        self.input_ids_desc, self.attention_mask_desc,
        self.token_type_ids_desc, self.masked_lm_labels_desc
    ], [self.loss_desc, self.prediction_scores_desc])

    # Hard-coded stand-ins for command-line args.
    args_gradient_accumulation_steps = 8
    args_local_rank = 0
    args_world_size = 1
    args_fp16 = True
    args_allreduce_post_accumulation = True

    # Rebind `model` to the ORT-wrapped trainer.
    model = ORTTrainer(
        model, None, model_desc, "LambOptimizer",
        map_optimizer_attributes=map_optimizer_attributes,
        learning_rate_description=IODescription('Learning_Rate', [1, ],
                                                torch.float32),
        device=self.device,
        postprocess_model=postprocess_model,
        gradient_accumulation_steps=args_gradient_accumulation_steps,
        world_rank=args_local_rank, world_size=args_world_size,
        use_mixed_precision=True if args_fp16 else False,
        allreduce_post_accumulation=True if args_allreduce_post_accumulation else False)

    # Forward through ORT; return value intentionally unchecked here.
    model(input_ids, attention_mask=input_mask,
          token_type_ids=token_type_ids, masked_lm_labels=token_labels)
def bert_model_description(args): vocab_size = 30528 # allow variable input sizes: # input_ids_desc = IODescription('input_ids', ['batch', 'max_seq_len_in_batch'], torch.int64, num_classes = vocab_size) # segment_ids_desc = IODescription('segment_ids', ['batch', 'max_seq_len_in_batch'], torch.int64, num_classes = 2) # input_mask_desc = IODescription('input_mask', ['batch', 'max_seq_len_in_batch'], torch.int64, num_classes = 2) # masked_lm_labels_desc = IODescription('masked_lm_labels', ['batch', 'max_seq_len_in_batch'], torch.int64, num_classes = vocab_size) # next_sentence_labels_desc = IODescription('next_sentence_labels', ['batch',], torch.int64, num_classes = 2) # set concrete input sizes to permit optimization input_ids_desc = IODescription('input_ids', [args.train_batch_size, args.max_seq_length], torch.int64, num_classes = vocab_size) segment_ids_desc = IODescription('segment_ids', [args.train_batch_size, args.max_seq_length], torch.int64, num_classes = 2) input_mask_desc = IODescription('input_mask', [args.train_batch_size, args.max_seq_length], torch.int64, num_classes = 2) masked_lm_labels_desc = IODescription('masked_lm_labels', [args.train_batch_size, args.max_seq_length], torch.int64, num_classes = vocab_size) next_sentence_labels_desc = IODescription('next_sentence_labels', [args.train_batch_size,2], torch.int64, num_classes = 2) loss_desc = IODescription('loss', [], torch.float32) return ModelDescription([input_ids_desc, segment_ids_desc, input_mask_desc, masked_lm_labels_desc, next_sentence_labels_desc], [loss_desc])
def bert_model_description():
    """Model description for BERT pre-training: five int64 inputs over
    [batch, max_seq_len_in_batch] (labels over [batch]), scalar loss out."""
    vocab_size = 30528
    # (name, num_classes) for the per-token inputs.
    seq_specs = (
        ("input_ids", vocab_size),
        ("segment_ids", 2),
        ("input_mask", 2),
        ("masked_lm_labels", vocab_size),
    )
    inputs = [
        IODescription(name, ["batch", "max_seq_len_in_batch"], torch.int64,
                      num_classes=classes)
        for name, classes in seq_specs
    ]
    inputs.append(
        IODescription("next_sentence_labels", ["batch", ], torch.int64,
                      num_classes=2))
    return ModelDescription(inputs, [IODescription("loss", [], torch.float32)])
def model_to_desc(self, model_name, model):
    """Return (legacy ModelDescription, new-style dict description) for the
    given transformer family.

    bert/xlnet models consume token_type_ids; roberta models do not.
    Raises RuntimeError for any other model family.

    Fix: the two branches were ~40 lines of copy-paste differing only in
    whether token_type_ids is present — deduplicated; produced descriptions
    are value-identical to the originals.
    """
    if model_name.startswith('bert') or model_name.startswith('xlnet'):
        seq_input_names = ['input_ids', 'attention_mask', 'token_type_ids']
    elif model_name.startswith('roberta'):
        seq_input_names = ['input_ids', 'attention_mask']
    else:
        raise RuntimeError(
            "unsupported base model name {}.".format(model_name))

    # Each sequence input gets a fresh [batch, max_seq_len_in_batch] shape list.
    new_model_desc = {
        'inputs': [(name, ['batch', 'max_seq_len_in_batch'],)
                   for name in seq_input_names] + [('labels', ['batch', ],)],
        'outputs': [('loss', [], True), ('logits', ['batch', 2])]
    }
    model_desc = ModelDescription(
        [IODescription(name, ['batch', 'max_seq_len_in_batch'])
         for name in seq_input_names] + [IODescription('labels', ['batch', ])],
        [IODescription('loss', []),
         IODescription('logits', ['batch', 2])])
    return model_desc, new_model_desc
def run_multiple_choice(self, model_name, task_name, fp16):
    """Fine-tune *model_name* on the SWAG multiple-choice *task_name* with
    ORTTransformerTrainer and return the evaluation results dict."""
    model_args = ModelArguments(model_name_or_path=model_name,
                                cache_dir=self.cache_dir)
    data_args = DataTrainingArguments(task_name=task_name,
                                      data_dir=self.data_dir,
                                      max_seq_length=self.max_seq_length)
    training_args = TrainingArguments(
        output_dir=os.path.join(self.output_dir, task_name),
        do_train=True,
        do_eval=True,
        per_gpu_train_batch_size=self.train_batch_size,
        per_gpu_eval_batch_size=self.eval_batch_size,
        learning_rate=self.learning_rate,
        num_train_epochs=self.num_train_epochs,
        local_rank=self.local_rank,
        overwrite_output_dir=self.overwrite_output_dir,
        gradient_accumulation_steps=self.gradient_accumulation_steps,
        fp16=fp16,
        logging_steps=self.logging_steps)

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
    )
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        training_args.local_rank,
        training_args.device,
        training_args.n_gpu,
        bool(training_args.local_rank != -1),
        training_args.fp16,
    )
    logger.info("Training/evaluation parameters %s", training_args)

    # Seed both PyTorch/transformers and onnxruntime for reproducibility.
    set_seed(training_args.seed)
    onnxruntime.set_seed(training_args.seed)

    try:
        processor = SwagProcessor()
        label_list = processor.get_labels()
        num_labels = len(label_list)
    except KeyError:
        raise ValueError("Task not found: %s" % (data_args.task_name))

    config = AutoConfig.from_pretrained(
        model_args.config_name
        if model_args.config_name else model_args.model_name_or_path,
        num_labels=num_labels,
        finetuning_task=data_args.task_name,
        cache_dir=model_args.cache_dir,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.tokenizer_name
        if model_args.tokenizer_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
    )
    model = AutoModelForMultipleChoice.from_pretrained(
        model_args.model_name_or_path,
        from_tf=bool(".ckpt" in model_args.model_name_or_path),
        config=config,
        cache_dir=model_args.cache_dir,
    )

    # Get datasets
    train_dataset = (MultipleChoiceDataset(
        data_dir=data_args.data_dir,
        tokenizer=tokenizer,
        task=data_args.task_name,
        processor=processor,
        max_seq_length=data_args.max_seq_length,
        overwrite_cache=data_args.overwrite_cache,
        mode=Split.train,
    ) if training_args.do_train else None)
    eval_dataset = (MultipleChoiceDataset(
        data_dir=data_args.data_dir,
        tokenizer=tokenizer,
        task=data_args.task_name,
        processor=processor,
        max_seq_length=data_args.max_seq_length,
        overwrite_cache=data_args.overwrite_cache,
        mode=Split.dev,
    ) if training_args.do_eval else None)

    def compute_metrics(p: EvalPrediction) -> Dict:
        # Predicted choice = argmax over the num_labels axis.
        preds = np.argmax(p.predictions, axis=1)
        return {"acc": simple_accuracy(preds, p.label_ids)}

    # bert models take token_type_ids (concrete shapes); others do not
    # (symbolic shapes).
    if model_name.startswith('bert'):
        model_desc = ModelDescription([
            IODescription('input_ids', [
                self.train_batch_size, num_labels, data_args.max_seq_length
            ], torch.int64, num_classes=model.config.vocab_size),
            IODescription('attention_mask', [
                self.train_batch_size, num_labels, data_args.max_seq_length
            ], torch.int64, num_classes=2),
            IODescription('token_type_ids', [
                self.train_batch_size, num_labels, data_args.max_seq_length
            ], torch.int64, num_classes=2),
            IODescription('labels', [self.train_batch_size, num_labels],
                          torch.int64, num_classes=num_labels)
        ], [
            IODescription('loss', [], torch.float32),
            IODescription('reshaped_logits',
                          [self.train_batch_size, num_labels], torch.float32)
        ])
    else:
        model_desc = ModelDescription([
            IODescription('input_ids',
                          ['batch', num_labels, 'max_seq_len_in_batch'],
                          torch.int64, num_classes=model.config.vocab_size),
            IODescription('attention_mask',
                          ['batch', num_labels, 'max_seq_len_in_batch'],
                          torch.int64, num_classes=2),
            IODescription('labels', ['batch', num_labels], torch.int64,
                          num_classes=num_labels)
        ], [
            IODescription('loss', [], torch.float32),
            IODescription('reshaped_logits', ['batch', num_labels],
                          torch.float32)
        ])

    # Initialize the ORTTrainer within ORTTransformerTrainer
    trainer = ORTTransformerTrainer(
        model=model,
        model_desc=model_desc,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        compute_metrics=compute_metrics,
    )

    # Training
    if training_args.do_train:
        trainer.train()
        trainer.save_model()

    # Evaluation
    results = {}
    if training_args.do_eval and training_args.local_rank in [-1, 0]:
        logger.info("*** Evaluate ***")
        result = trainer.evaluate()
        logger.info("***** Eval results {} *****".format(data_args.task_name))
        for key, value in result.items():
            logger.info(" %s = %s", key, value)
        results.update(result)
    return results
def test_extra_postpass(self):
    """_extra_postprocess must run exactly once on the exported graph, and
    must not run at all when postprocessing is disabled."""

    def postpass_replace_first_add_with_sub(model):
        # this post pass replaces the first Add node with Sub in the model.
        # Previous graph
        #   (subgraph 1)   (subgraph 2)
        #        |              |
        #        |______  ______|
        #               ||
        #               Add
        #                |
        #           (subgraph 3)
        #
        # Post graph
        #   (subgraph 1)   (subgraph 2)
        #        |              |
        #        |______  ______|
        #               ||
        #               Sub
        #                |
        #           (subgraph 3)
        add_nodes = [n for n in model.graph.node if n.op_type == 'Add']
        add_nodes[0].op_type = "Sub"

    class MultiAdd(nn.Module):
        def __init__(self, target):
            super(MultiAdd, self).__init__()
            self.loss = nn.CrossEntropyLoss()
            self.target = target
            self.linear = torch.nn.Linear(2, 2, bias=False)

        def forward(self, x, x1):
            # Three Adds in a row; the postpass turns exactly one into Sub.
            output = x + x1
            output = output + x
            output = output + x1
            output = self.linear(output)
            loss = self.loss(output, self.target)
            return loss, output

    device = torch.device("cpu")
    target = torch.ones(5, 2, dtype=torch.int64).to(device)
    model = MultiAdd(target).to(device)
    x = torch.randn(5, 5, 2, dtype=torch.float32).to(device)
    x1 = torch.randn(5, 5, 2, dtype=torch.float32).to(device)
    input0_desc = IODescription('x', [5, 5, 2], "float32")
    input1_desc = IODescription('x1', [5, 5, 2], "float32")
    output0_desc = IODescription('output0', [], "float32")
    output1_desc = IODescription('output1', [5, 5, 2], "float32")
    model_desc = ModelDescription([input0_desc, input1_desc],
                                  [output0_desc, output1_desc])
    learning_rate = torch.tensor([1.0000000e+00]).to(device)
    input_args = [x, x1, learning_rate]
    onnx_model = self.get_onnx_model(
        model, model_desc, input_args, device,
        _extra_postprocess=postpass_replace_first_add_with_sub)

    # check that extra postpass is called, and called only once.
    add_nodes = self.find_nodes(onnx_model, "Add")
    sub_nodes = self.find_nodes(onnx_model, "Sub")
    assert len(add_nodes) == 2
    assert len(sub_nodes) == 1

    unprocessed_onnx_model = self.get_onnx_model(
        model, model_desc, input_args, device,
        _extra_postprocess=None, _enable_internal_postprocess=False)
    # check that the model is unchanged.
    add_nodes = self.find_nodes(unprocessed_onnx_model, "Add")
    sub_nodes = self.find_nodes(unprocessed_onnx_model, "Sub")
    assert len(add_nodes) == 3
    assert len(sub_nodes) == 0

    # Re-exporting an already-exported model with the postpass enabled
    # should again apply it exactly once.
    processed_onnx_model = self.get_onnx_model(
        unprocessed_onnx_model, model_desc, input_args, device,
        _extra_postprocess=postpass_replace_first_add_with_sub)
    add_nodes = self.find_nodes(processed_onnx_model, "Add")
    sub_nodes = self.find_nodes(processed_onnx_model, "Sub")
    assert len(add_nodes) == 2
    assert len(sub_nodes) == 1
def train(self):
    """
    Main training entry point.

    Builds an ORTTrainer (AdamOptimizer, linear warmup schedule, dynamic
    loss scaling) around self.model and runs the epoch/step loop with
    gradient accumulation and periodic logging. Returns a TrainOutput of
    (global_step, average training loss).
    """
    train_dataloader = self.get_train_dataloader()

    # Derive total optimization steps and epoch count.
    if self.args.max_steps > 0:
        t_total = self.args.max_steps
        num_train_epochs = (self.args.max_steps //
                            (len(train_dataloader) //
                             self.args.gradient_accumulation_steps) + 1)
    else:
        t_total = int(len(train_dataloader) //
                      self.args.gradient_accumulation_steps *
                      self.args.num_train_epochs)
        num_train_epochs = self.args.num_train_epochs

    get_lr_this_step = get_linear_schedule_with_warmup(
        self.args.warmup_steps, t_total, self.args.learning_rate)
    loss_scaler = LossScaler('loss_scale_input_name', True,
                             up_scale_window=2000)

    def map_optimizer_attributes(name):
        # no_decay_keys = ["bias", "LayerNorm.weight"]
        no_decay = "bias" in name or "LayerNorm.weight" in name
        if no_decay:
            return {"weight_decay": 0.0, "weight_decay_mode": 1}
        else:
            return {"weight_decay": self.args.weight_decay,
                    "weight_decay_mode": 1}

    # Rebind self.model to the ORT-wrapped trainer.
    self.model = ORTTrainer(
        self.model,
        None,
        self.model_desc,
        "AdamOptimizer",
        map_optimizer_attributes=map_optimizer_attributes,
        learning_rate_description=IODescription('Learning_Rate', [1, ],
                                                torch.float32),
        device=self.args.device,
        gradient_accumulation_steps=self.args.gradient_accumulation_steps,
        world_rank=0,
        world_size=1,  # only support single GPU cases
        use_mixed_precision=self.args.fp16,
        allreduce_post_accumulation=True,
        get_lr_this_step=get_lr_this_step,
        loss_scaler=loss_scaler,
        enable_grad_norm_clip=False,
        _opset_version=12,
        _use_deterministic_compute=True)

    # Train!
    logger.info("***** Running training *****")
    logger.info(" Num examples = %d", len(train_dataloader.dataset))
    logger.info(" Num Epochs = %d", num_train_epochs)
    logger.info(" Instantaneous batch size per GPU = %d",
                self.args.per_gpu_train_batch_size)
    logger.info(
        " Total train batch size (w. parallel, distributed & accumulation) = %d",
        self.args.train_batch_size * self.args.gradient_accumulation_steps *
        (torch.distributed.get_world_size()
         if self.args.local_rank != -1 else 1),
    )
    logger.info(" Gradient Accumulation steps = %d",
                self.args.gradient_accumulation_steps)
    logger.info(" Total optimization steps = %d", t_total)

    global_step = 0
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    tr_loss = 0.0
    logging_loss = 0.0
    train_iterator = trange(
        epochs_trained,
        int(num_train_epochs),
        desc="Epoch",
        disable=self.args.local_rank not in [-1, 0],
    )
    for epoch in train_iterator:
        epoch_iterator = tqdm(train_dataloader, desc="Iteration",
                              disable=self.args.local_rank not in [-1, 0])
        for step, inputs in enumerate(epoch_iterator):
            # Skip past any already trained steps if resuming training
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                continue
            tr_loss += self._training_step(self.model, inputs)
            # An optimizer step completes every gradient_accumulation_steps
            # micro-batches (or at the end of a short epoch).
            if (step + 1) % self.args.gradient_accumulation_steps == 0 or (
                    len(epoch_iterator) <= self.args.gradient_accumulation_steps
                    and (step + 1) == len(epoch_iterator)):
                global_step += 1
                if self.args.local_rank in [-1, 0]:
                    if (self.args.logging_steps > 0 and
                            global_step % self.args.logging_steps == 0) or (
                                global_step == 1 and
                                self.args.logging_first_step):
                        logs = {}
                        if self.args.evaluate_during_training:
                            results = self.evaluate()
                            for key, value in results.items():
                                eval_key = "eval_{}".format(key)
                                logs[eval_key] = value
                        loss_scalar = (tr_loss - logging_loss
                                       ) / self.args.logging_steps
                        learning_rate_scalar = get_lr_this_step(global_step)
                        logs["learning_rate"] = learning_rate_scalar
                        logs["loss"] = loss_scalar
                        logging_loss = tr_loss
                        epoch_iterator.write(
                            json.dumps({
                                **logs,
                                **{
                                    "step": global_step
                                }
                            }))
            if self.args.max_steps > 0 and global_step > self.args.max_steps:
                epoch_iterator.close()
                break
        if self.args.max_steps > 0 and global_step > self.args.max_steps:
            train_iterator.close()
            break

    logger.info("\n\nTraining completed. \n\n")
    return TrainOutput(global_step, tr_loss / global_step)
def main():
    """Train and evaluate a small MNIST classifier with ORTTrainer.

    Parses CLI arguments, builds the MNIST train/test loaders, resolves the
    local/world rank from the OpenMPI environment (when launched via mpirun),
    constructs the ORT trainer with an SGD optimizer, then runs one train and
    one test pass per epoch.
    """
    # Training settings
    parser = argparse.ArgumentParser(description="PyTorch MNIST Example")
    parser.add_argument("--batch-size", type=int, default=64, metavar="N",
                        help="input batch size for training (default: 64)")
    parser.add_argument("--test-batch-size", type=int, default=1000, metavar="N",
                        help="input batch size for testing (default: 1000)")
    parser.add_argument("--epochs", type=int, default=10, metavar="N",
                        help="number of epochs to train (default: 10)")
    parser.add_argument("--lr", type=float, default=0.01, metavar="LR",
                        help="learning rate (default: 0.01)")
    parser.add_argument("--no-cuda", action="store_true", default=False,
                        help="disables CUDA training")
    parser.add_argument("--seed", type=int, default=1, metavar="S",
                        help="random seed (default: 1)")
    parser.add_argument(
        "--log-interval",
        type=int,
        default=10,
        metavar="N",
        help="how many batches to wait before logging training status",
    )
    args = parser.parse_args()
    use_cuda = not args.no_cuda and torch.cuda.is_available()

    torch.manual_seed(args.seed)

    kwargs = {"num_workers": 0, "pin_memory": True}
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST(
            "../data",
            train=True,
            download=True,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307, ), (0.3081, ))
            ]),
        ),
        batch_size=args.batch_size,
        shuffle=True,
        **kwargs,
    )
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST(
            "../data",
            train=False,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307, ), (0.3081, ))
            ]),
        ),
        batch_size=args.test_batch_size,
        shuffle=True,
        **kwargs,
    )

    # Rank / world size come from the OpenMPI environment when present;
    # fall back to single-process defaults otherwise.
    comm = MPI.COMM_WORLD
    args.local_rank = (int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"])
                       if ("OMPI_COMM_WORLD_LOCAL_RANK" in os.environ) else 0)
    args.world_rank = (int(os.environ["OMPI_COMM_WORLD_RANK"])
                       if ("OMPI_COMM_WORLD_RANK" in os.environ) else 0)
    args.world_size = comm.Get_size()
    if use_cuda:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        args.n_gpu = 1
        set_cuda_device_id(args.local_rank)
    else:
        device = torch.device("cpu")

    input_size = 784
    hidden_size = 500
    num_classes = 10
    model = NeuralNet(input_size, hidden_size, num_classes)
    model_desc = mnist_model_description()
    # use log_interval as gradient accumulate steps
    trainer = ORTTrainer(
        model,
        my_loss,
        model_desc,
        "SGDOptimizer",
        None,
        IODescription(
            "Learning_Rate",
            [
                1,
            ],
            torch.float32,
        ),
        device,
        1,
        args.world_rank,
        args.world_size,
        use_mixed_precision=False,
        allreduce_post_accumulation=True,
    )
    print("\nBuild ort model done.")

    for epoch in range(1, args.epochs + 1):
        train_with_trainer(args, trainer, device, train_loader, epoch)
        # removed leftover debug `import pdb`
        test_with_trainer(args, trainer, device, test_loader)
def __init__(
    self,
    parent,
    batch_size=13,
    seq_length=7,
    is_training=True,
    use_input_mask=True,
    use_token_type_ids=True,
    use_labels=True,
    vocab_size=99,
    hidden_size=32,
    num_hidden_layers=5,
    num_attention_heads=4,
    intermediate_size=37,
    hidden_act="gelu",
    hidden_dropout_prob=0.1,
    attention_probs_dropout_prob=0.1,
    max_position_embeddings=512,
    type_vocab_size=16,
    type_sequence_label_size=2,
    initializer_range=0.02,
    num_labels=3,
    num_choices=4,
    scope=None,
    device='cpu',
):
    """Store the hyper-parameters of a tiny test BERT model and build the
    IODescription objects for every input/output the model may expose.

    `parent` is the owning test case; the remaining arguments are model
    hyper-parameters (deliberately small so tests run fast). All descriptions
    use symbolic dims ('batch', 'max_seq_len_in_batch') for variable sizes.
    """
    self.parent = parent
    self.batch_size = batch_size
    self.seq_length = seq_length
    self.is_training = is_training
    self.use_input_mask = use_input_mask
    self.use_token_type_ids = use_token_type_ids
    self.use_labels = use_labels
    self.vocab_size = vocab_size
    self.hidden_size = hidden_size
    self.num_hidden_layers = num_hidden_layers
    self.num_attention_heads = num_attention_heads
    self.intermediate_size = intermediate_size
    self.hidden_act = hidden_act
    self.hidden_dropout_prob = hidden_dropout_prob
    self.attention_probs_dropout_prob = attention_probs_dropout_prob
    self.max_position_embeddings = max_position_embeddings
    self.type_vocab_size = type_vocab_size
    self.type_sequence_label_size = type_sequence_label_size
    self.initializer_range = initializer_range
    self.num_labels = num_labels
    self.num_choices = num_choices
    self.scope = scope
    self.device = device

    # 1. superset of bert input/output descs
    # see BertPreTrainedModel doc
    self.input_ids_desc = IODescription(
        'input_ids', ['batch', 'max_seq_len_in_batch'],
        torch.int64,
        num_classes=self.vocab_size)
    self.attention_mask_desc = IODescription(
        'attention_mask', ['batch', 'max_seq_len_in_batch'],
        torch.int64,
        num_classes=2)
    self.token_type_ids_desc = IODescription(
        'token_type_ids', ['batch', 'max_seq_len_in_batch'],
        torch.int64,
        num_classes=2)
    self.position_ids_desc = IODescription(
        'position_ids', ['batch', 'max_seq_len_in_batch'],
        torch.int64,
        num_classes=self.max_position_embeddings)
    # head mask is per-layer, per-head (fixed dims, not symbolic)
    self.head_mask_desc = IODescription(
        'head_mask', [self.num_hidden_layers, self.num_attention_heads],
        torch.int64,
        num_classes=2)
    self.inputs_embeds_desc = IODescription(
        'inputs_embeds', ['batch', 'max_seq_len_in_batch', self.hidden_size],
        torch.float32)

    self.encoder_hidden_states_desc = IODescription(
        'encoder_hidden_states',
        ['batch', 'max_seq_len_in_batch', self.hidden_size], torch.float32)
    self.encoder_attention_mask_desc = IODescription(
        'encoder_attention_mask', ['batch', 'max_seq_len_in_batch'],
        torch.float32)

    # see BertForPreTraining doc
    self.masked_lm_labels_desc = IODescription(
        'masked_lm_labels', ['batch', 'max_seq_len_in_batch'],
        torch.int64,
        num_classes=self.vocab_size)
    self.next_sentence_label_desc = IODescription(
        'next_sentence_label', [
            'batch',
        ], torch.int64, num_classes=2)

    # outputs
    self.loss_desc = IODescription('loss', [
        1,
    ], torch.float32)
    self.prediction_scores_desc = IODescription(
        'prediction_scores',
        ['batch', 'max_seq_len_in_batch', self.vocab_size], torch.float32)
    self.seq_relationship_scores_desc = IODescription(
        'seq_relationship_scores', ['batch', 2], torch.float32
    )  # IODescription('seq_relationship_scores', ['batch', 'max_seq_len_in_batch', 2], torch.float32)
    self.hidden_states_desc = IODescription('hidden_states', [
        self.num_hidden_layers, 'batch', 'max_seq_len_in_batch',
        self.hidden_size
    ], torch.float32)
    self.attentions_desc = IODescription('attentions', [
        self.num_hidden_layers, 'batch', self.num_attention_heads,
        'max_seq_len_in_batch', 'max_seq_len_in_batch'
    ], torch.float32)
    self.last_hidden_state_desc = IODescription(
        'last_hidden_state', ['batch', 'max_seq_len_in_batch', self.hidden_size],
        torch.float32)
    self.pooler_output_desc = IODescription(
        'pooler_output', ['batch', self.hidden_size], torch.float32)
def run_test(model, model_desc, device, args, gradient_accumulation_steps,
             fp16, allreduce_post_accumulation, get_lr_this_step,
             use_internal_get_lr_this_step, loss_scaler,
             use_internal_loss_scaler, batch_args_option, dataset_len, epochs,
             use_new_api):
    """Train `model` for `epochs` epochs on a synthetic dataset, then run one
    eval step, exercising either the new (`orttrainer.ORTTrainer`) or the old
    (`ORTTrainer`) frontend API.

    `batch_args_option` controls how each batch is fed to train/eval steps:
    all-positional (List), all-keyword (Dict), or a roughly half/half split.
    Returns a generator over the eval-step outputs as numpy arrays.
    """
    dataloader = create_ort_test_dataloader(model_desc.inputs_,
                                            args.batch_size, args.seq_len,
                                            dataset_len, device)

    if use_new_api:
        assert use_internal_loss_scaler, 'new api should always use internal loss scaler'

        new_api_lr_scheduler = WrapLRScheduler(get_lr_this_step)

        new_api_loss_scaler = amp.DynamicLossScaler() if fp16 else None
        options = orttrainer.ORTTrainerOptions({
            'batch': {
                'gradient_accumulation_steps': gradient_accumulation_steps
            },
            'device': {
                'id': device
            },
            'mixed_precision': {
                'enabled': fp16,
                'loss_scaler': new_api_loss_scaler
            },
            'debug': {
                'deterministic_compute': True,
            },
            'utils': {
                'grad_norm_clip': True
            },
            'distributed': {
                # NOTE(review): hard-coded True -- the allreduce_post_accumulation
                # parameter is ignored on this (new-API) path; confirm intended.
                'allreduce_post_accumulation': True
            },
            'lr_scheduler': new_api_lr_scheduler
        })

        param_optimizer = list(model.named_parameters())
        # Two Lamb parameter groups (referenced by parameter *name*):
        # biases + LayerNorm weights, and everything else.
        params = [{
            'params': [
                n for n, p in param_optimizer
                if "bias" in n or "LayerNorm.weight" in n
            ],
            "alpha": 0.9,
            "beta": 0.999,
            "lambda": 0.0,
            "epsilon": 1e-6
        }, {
            'params': [
                n for n, p in param_optimizer
                if not ("bias" in n or "LayerNorm.weight" in n)
            ],
            "alpha": 0.9,
            "beta": 0.999,
            "lambda": 0.0,
            "epsilon": 1e-6
        }]

        vocab_size = 99
        # Dict-style model description used by the new frontend API.
        new_model_desc = {
            'inputs': [(
                'input_ids',
                ['batch', 'max_seq_len_in_batch'],
            ), (
                'attention_mask',
                ['batch', 'max_seq_len_in_batch'],
            ), (
                'token_type_ids',
                ['batch', 'max_seq_len_in_batch'],
            ), (
                'masked_lm_labels',
                ['batch', 'max_seq_len_in_batch'],
            ), ('next_sentence_label', [
                'batch',
            ])],
            'outputs': [('loss', [
                1,
            ], True),
                        ('prediction_scores',
                         ['batch', 'max_seq_len_in_batch', vocab_size]),
                        ('seq_relationship_scores', ['batch', 2])]
        }

        optim_config = optim.LambConfig(params=params, lr=2e-5)
        model = orttrainer.ORTTrainer(model,
                                      new_model_desc,
                                      optim_config,
                                      options=options)
        print("running with new frontend API")
    else:
        model = ORTTrainer(
            model,
            None,
            model_desc,
            "LambOptimizer",
            map_optimizer_attributes=map_optimizer_attributes,
            learning_rate_description=IODescription('Learning_Rate', [
                1,
            ], torch.float32),
            device=device,
            _enable_internal_postprocess=True,
            gradient_accumulation_steps=gradient_accumulation_steps,
            # BertLAMB default initial settings: b1=0.9, b2=0.999, e=1e-6
            world_rank=args.local_rank,
            world_size=args.world_size,
            use_mixed_precision=fp16,
            allreduce_post_accumulation=allreduce_post_accumulation,
            get_lr_this_step=get_lr_this_step
            if use_internal_get_lr_this_step else None,
            loss_scaler=loss_scaler if use_internal_loss_scaler else None,
            _opset_version=14,
            _use_deterministic_compute=True)
        print("running with old frontend API")

    # training loop
    eval_batch = None
    if not use_new_api:
        model.train()
    for epoch in range(epochs):
        for step, batch in enumerate(dataloader):
            if eval_batch is None:
                # NOTE(review): the first batch is saved here but the eval
                # section below uses the loop's last `batch` instead of
                # `eval_batch` -- confirm which is intended.
                eval_batch = batch

            if not use_internal_get_lr_this_step:
                lr = get_lr_this_step(step)
                learning_rate = torch.tensor([lr])

            if not use_internal_loss_scaler and fp16:
                loss_scale = torch.tensor([loss_scaler.loss_scale_])

            if batch_args_option == BatchArgsOption.List:
                if not use_internal_get_lr_this_step:
                    batch = batch + [
                        learning_rate,
                    ]
                if not use_internal_loss_scaler and fp16:
                    batch = batch + [
                        loss_scale,
                    ]
                outputs = model.train_step(*batch)
            elif batch_args_option == BatchArgsOption.Dict:
                args, kwargs = split_batch(batch, model_desc.inputs_, 0)
                if not use_internal_get_lr_this_step:
                    kwargs['Learning_Rate'] = learning_rate
                if not use_internal_loss_scaler and fp16:
                    kwargs[model.loss_scale_input_name] = loss_scale
                outputs = model.train_step(*args, **kwargs)
            else:
                args_count = int(len(model_desc.inputs_) /
                                 2)  # approximately half args, half kwargs
                args, kwargs = split_batch(batch, model_desc.inputs_,
                                           args_count)
                if not use_internal_get_lr_this_step:
                    kwargs['Learning_Rate'] = learning_rate
                if not use_internal_loss_scaler and fp16:
                    kwargs[model.loss_scale_input_name] = loss_scale
                outputs = model.train_step(*args, **kwargs)

    # eval (uses the last training batch; see note above)
    if batch_args_option == BatchArgsOption.List:
        outputs = model.eval_step(*batch)
    elif batch_args_option == BatchArgsOption.Dict:
        args, kwargs = split_batch(batch, model_desc.inputs_, 0)
        outputs = model.eval_step(*args, **kwargs)
    else:
        args_count = int(len(model_desc.inputs_) /
                         2)  # approximately half args, half kwargs
        args, kwargs = split_batch(batch, model_desc.inputs_, args_count)
        outputs = model.eval_step(*args, **kwargs)

    return (output.cpu().numpy() for output in outputs)
def model_to_desc(self, model_name, model):
    """Build the ModelDescription matching `model_name`'s input signature.

    bert/xlnet variants consume (input_ids, attention_mask, token_type_ids,
    labels); roberta variants consume the same minus token_type_ids. Every
    variant produces a scalar loss and per-example 2-class logits.

    Raises RuntimeError for any other base model name.
    """
    if not model_name.startswith(('bert', 'xlnet', 'roberta')):
        raise RuntimeError(
            "unsupported base model name {}.".format(model_name))

    seq_dims = ['batch', 'max_seq_len_in_batch']
    inputs = [
        IODescription('input_ids', seq_dims, torch.int64,
                      num_classes=model.config.vocab_size),
        IODescription('attention_mask', seq_dims, torch.int64,
                      num_classes=2),
    ]
    # roberta has no token-type (segment) embedding input.
    if not model_name.startswith('roberta'):
        inputs.append(
            IODescription('token_type_ids', seq_dims, torch.int64,
                          num_classes=2))
    inputs.append(
        IODescription('labels', [
            'batch',
        ], torch.int64, num_classes=2))

    outputs = [
        IODescription('loss', [], torch.float32),
        IODescription('logits', ['batch', 2], torch.float32),
    ]
    return ModelDescription(inputs, outputs)
def main():
    """Train an MNIST classifier with either ORTTrainer or ORTModel.

    Parses CLI arguments, builds the MNIST loaders, resolves the OpenMPI
    rank environment, then trains with ORTTrainer (--use-ort-trainer) or
    wraps the model in ORTModel and trains with a plain torch SGD optimizer.

    Bug fixed: `torch.cuda.set_device`, `args.n_gpu`, and
    `set_cuda_device_id` were executed unconditionally, crashing CPU-only
    runs (--no-cuda or no CUDA available); they are now guarded by
    `use_cuda`, matching the sibling MNIST `main`.
    """
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs', type=int, default=10, metavar='N',
                        help='number of epochs to train (default: 10)')
    parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
                        help='learning rate (default: 0.01)')
    parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
                        help='SGD momentum (default: 0.5)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                        help='how many batches to wait before logging training status')
    parser.add_argument('--save-model', action='store_true', default=False,
                        help='For Saving the current Model')
    parser.add_argument('--use-ort', action='store_true', default=False,
                        help='to use onnxruntime as training backend')
    parser.add_argument('--use-ort-trainer', action='store_true', default=False,
                        help='to use onnxruntime as training backend')
    args = parser.parse_args()
    use_cuda = not args.no_cuda and torch.cuda.is_available()

    torch.manual_seed(args.seed)

    kwargs = {'num_workers': 0, 'pin_memory': True}
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=True, download=True,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])),
        batch_size=args.batch_size, shuffle=True, **kwargs)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=False,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))])),
        batch_size=args.test_batch_size, shuffle=True, **kwargs)

    # Rank / world size come from the OpenMPI environment when launched
    # via mpirun; default to a single process otherwise.
    comm = MPI.COMM_WORLD
    args.local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']) if (
        'OMPI_COMM_WORLD_LOCAL_RANK' in os.environ) else 0
    args.world_rank = int(os.environ['OMPI_COMM_WORLD_RANK']) if (
        'OMPI_COMM_WORLD_RANK' in os.environ) else 0
    args.world_size = comm.Get_size()
    if use_cuda:
        # Only touch the CUDA runtime when a GPU is actually in use.
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        args.n_gpu = 1
        set_cuda_device_id(args.local_rank)
    else:
        device = torch.device("cpu")

    input_size = 784
    hidden_size = 500
    num_classes = 10
    model = NeuralNet(input_size, hidden_size, num_classes)
    model_desc = mnist_model_description()
    if args.use_ort_trainer:
        # use log_interval as gradient accumulate steps
        trainer = ORTTrainer(model, my_loss, model_desc, "LambOptimizer", None,
                             IODescription('Learning_Rate', [1,], torch.float32),
                             device, 1, None, args.world_rank, args.world_size,
                             use_mixed_precision=False,
                             allreduce_post_accumulation=True)
        print('\nBuild ort model done.')
        for epoch in range(1, args.epochs + 1):
            train_with_trainer(args, trainer, device, train_loader, epoch)
            # removed leftover debug `import pdb`
            test_with_trainer(args, trainer, device, test_loader)
    else:
        model = ORTModel(model, my_loss, model_desc, device, None,
                         args.world_rank, args.world_size)
        print('\nBuild ort model done.')
        optimizer = optim.SGD(model.parameters(), lr=args.lr,
                              momentum=args.momentum)
        for epoch in range(1, args.epochs + 1):
            train_with_model(args, model, device, train_loader, optimizer, epoch)