def train(args, train_dataset, model, tokenizer):
    """Train the model.

    Builds a dataloader over ``train_dataset`` (``RandomSampler`` when
    ``args.local_rank == -1``, otherwise ``DistributedSampler``) and optimizes
    with AdamW plus a linear warmup/decay schedule. Depending on ``args`` the
    model is wrapped for apex fp16, ``DataParallel`` and/or
    ``DistributedDataParallel``.

    Every ``args.logging_steps`` optimizer updates it calls ``evaluate``; this
    only happens in single-process runs. Every ``args.save_steps`` updates it
    writes a checkpoint to ``args.output_dir/checkpoint-<global_step>``. When
    ``args.max_steps > 0``, training stops after that many optimizer updates.

    Side effects: sets ``args.train_batch_size`` and ``args.warmup_steps``. It
    also sets ``args.num_train_epochs`` when ``args.max_steps > 0``.

    NOTE(review): a second ``train`` defined later in this module shadows this
    definition once the module is fully imported.

    Returns:
        ``(global_step, average_loss)``. ``global_step`` is the number of
        optimizer updates. ``average_loss`` is the accumulated loss divided by
        ``max(global_step, 1)``, so a run with no update does not raise.
    """
    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    if args.local_rank == -1:
        train_sampler = RandomSampler(train_dataset)
    else:
        train_sampler = DistributedSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size,
                                  collate_fn=collate_fn)

    if args.max_steps > 0:
        num_training_steps = args.max_steps
        # max(1, ...) avoids a ZeroDivisionError when the dataloader yields
        # fewer batches than gradient_accumulation_steps.
        updates_per_epoch = max(
            1, len(train_dataloader) // args.gradient_accumulation_steps)
        args.num_train_epochs = args.max_steps // updates_per_epoch + 1
    else:
        num_training_steps = (len(train_dataloader) //
                              args.gradient_accumulation_steps *
                              args.num_train_epochs)
    args.warmup_steps = int(num_training_steps * args.warmup_proportion)

    # Prepare optimizer and schedule (linear warmup and decay). Biases and
    # LayerNorm weights are excluded from weight decay.
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params': [
            p for n, p in model.named_parameters()
            if not any(nd in n for nd in no_decay)
        ],
        'weight_decay': args.weight_decay
    }, {
        'params': [
            p for n, p in model.named_parameters()
            if any(nd in n for nd in no_decay)
        ],
        'weight_decay': 0.0
    }]
    optimizer = AdamW(params=optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=num_training_steps)
    if args.fp16:
        try:
            from apex import amp
        except ImportError as err:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            ) from err
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)
    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)
    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True)

    # Train!
    logger.info("***** Running training *****")
    logger.info(" Num examples = %d", len(train_dataset))
    logger.info(" Num Epochs = %d", args.num_train_epochs)
    logger.info(" Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        " Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
    logger.info(" Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info(" Total optimization steps = %d", num_training_steps)

    global_step = 0
    tr_loss = 0.0
    model.zero_grad()
    seed_everything(
        args.seed
    )  # Added here for reproductibility (even between python 2 and 3)
    for epoch in range(int(args.num_train_epochs)):
        if args.local_rank != -1:
            # DistributedSampler only reshuffles when told the epoch changed.
            train_sampler.set_epoch(epoch)
        pbar = ProgressBar(n_total=len(train_dataloader), desc='Training')
        for step, batch in enumerate(train_dataloader):
            # evaluate() may switch the model to eval mode, so re-enable
            # training mode on every step.
            model.train()
            batch = tuple(t.to(args.device) for t in batch)
            inputs = {
                'input_ids': batch[0],
                'attention_mask': batch[1],
                'labels': batch[3],
                'token_type_ids': batch[2]
            }
            outputs = model(**inputs)
            loss = outputs[
                0]  # model outputs are always tuple in transformers (see doc)
            if args.n_gpu > 1:
                loss = loss.mean(
                )  # mean() to average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps
            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()
            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                # Clip once per update, on the fully accumulated gradient.
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.max_grad_norm)
                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1
                if (args.local_rank in [-1, 0] and args.logging_steps > 0
                        and global_step % args.logging_steps == 0):
                    # Log metrics
                    if args.local_rank == -1:
                        # Only evaluate when single GPU otherwise metrics
                        # may not average well
                        evaluate(args, model, tokenizer)
                if (args.local_rank in [-1, 0] and args.save_steps > 0
                        and global_step % args.save_steps == 0):
                    # Save model checkpoint
                    output_dir = os.path.join(
                        args.output_dir, 'checkpoint-{}'.format(global_step))
                    os.makedirs(output_dir, exist_ok=True)
                    # Take care of distributed/parallel training
                    model_to_save = model.module if hasattr(
                        model, 'module') else model
                    model_to_save.save_pretrained(output_dir)
                    torch.save(args,
                               os.path.join(output_dir, 'training_args.bin'))
                    logger.info("Saving model checkpoint to %s", output_dir)
            pbar(step, {'loss': loss.item()})
            if args.max_steps > 0 and global_step >= args.max_steps:
                break
        print(" ")
        if 'cuda' in str(args.device):
            torch.cuda.empty_cache()
        if args.max_steps > 0 and global_step >= args.max_steps:
            break
    return global_step, tr_loss / max(global_step, 1)
def train(args, train_dataset, model, tokenizer):
    """Train the model (single-process variant, no fp16 / distributed wrapping).

    Samples ``train_dataset`` with ``RandomSampler`` and optimizes with AdamW
    plus a linear warmup/decay schedule. The model is wrapped in
    ``DataParallel`` when ``args.n_gpu > 1``. Every ``args.logging_steps``
    optimizer updates it calls ``evaluate``. When ``args.max_steps > 0``,
    training stops after that many updates. A single checkpoint is written to
    ``args.output_dir/checkpoint-<global_step>`` after training finishes.

    Side effects: sets ``args.train_batch_size`` and ``args.warmup_steps``. It
    also sets ``args.num_train_epochs`` when ``args.max_steps > 0``.

    Returns:
        ``(global_step, average_loss)``. ``average_loss`` is the accumulated
        loss divided by ``max(global_step, 1)``, so a run with no update does
        not raise.
    """
    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    train_sampler = RandomSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size,
                                  collate_fn=collate_fn)

    if args.max_steps > 0:
        num_training_steps = args.max_steps
        # max(1, ...) avoids a ZeroDivisionError when the dataloader yields
        # fewer batches than gradient_accumulation_steps.
        updates_per_epoch = max(
            1, len(train_dataloader) // args.gradient_accumulation_steps)
        args.num_train_epochs = args.max_steps // updates_per_epoch + 1
    else:
        num_training_steps = (len(train_dataloader) //
                              args.gradient_accumulation_steps *
                              args.num_train_epochs)
    args.warmup_steps = int(num_training_steps * args.warmup_proportion)

    # Biases and LayerNorm weights are excluded from weight decay.
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params': [
            p for n, p in model.named_parameters()
            if not any(nd in n for nd in no_decay)
        ],
        'weight_decay': args.weight_decay
    }, {
        'params': [
            p for n, p in model.named_parameters()
            if any(nd in n for nd in no_decay)
        ],
        'weight_decay': 0.0
    }]
    optimizer = AdamW(params=optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=num_training_steps)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    logger.info("***** Running training *****")
    logger.info(" Num examples = %d", len(train_dataset))
    logger.info(" Num Epochs = %d", args.num_train_epochs)
    logger.info(" Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(" Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info(" Total optimization steps = %d", num_training_steps)

    global_step = 0
    tr_loss = 0.0
    model.zero_grad()
    seed_everything(
        args.seed
    )  # Added here for reproductibility (even between python 2 and 3)
    for _ in range(int(args.num_train_epochs)):
        pbar = ProgressBar(n_total=len(train_dataloader), desc='Training')
        for step, batch in enumerate(train_dataloader):
            # evaluate() may switch the model to eval mode, so re-enable
            # training mode on every step.
            model.train()
            batch = tuple(t.to(args.device) for t in batch)
            inputs = {
                'input_ids': batch[0],
                'attention_mask': batch[1],
                'labels': batch[3]
            }
            inputs['token_type_ids'] = batch[2]
            outputs = model(**inputs)
            loss = outputs[0]
            if args.n_gpu > 1:
                loss = loss.mean()  # average the per-GPU losses
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps
            loss.backward()
            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                # Clip once per update, on the fully accumulated gradient.
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               args.max_grad_norm)
                optimizer.step()
                scheduler.step()
                model.zero_grad()
                global_step += 1
                if (args.local_rank in [-1, 0] and args.logging_steps > 0
                        and global_step % args.logging_steps == 0):
                    # Log metrics
                    if args.local_rank == -1:
                        evaluate(args, model, tokenizer)
            pbar(step, {'loss': loss.item()})
            if args.max_steps > 0 and global_step >= args.max_steps:
                break
        print(" ")
        if 'cuda' in str(args.device):
            torch.cuda.empty_cache()
        if args.max_steps > 0 and global_step >= args.max_steps:
            break

    # Checkpoint once, after training has finished.
    output_dir = os.path.join(args.output_dir,
                              'checkpoint-{}'.format(global_step))
    os.makedirs(output_dir, exist_ok=True)
    model_to_save = model.module if hasattr(model, 'module') else model
    model_to_save.save_pretrained(output_dir)
    torch.save(args, os.path.join(output_dir, 'training_args.bin'))
    logger.info("Saving model checkpoint to %s", output_dir)
    return global_step, tr_loss / max(global_step, 1)
def _format_elapsed(seconds):
    """Render a duration in seconds as 'H:MM:SS', 'M:SS' or 'Ns'."""
    if seconds > 3600:
        return '%d:%02d:%02d' % (seconds // 3600,
                                 (seconds % 3600) // 60, seconds % 60)
    if seconds > 60:
        return '%d:%02d' % (seconds // 60, seconds % 60)
    return '%ds' % seconds


def _save_checkpoint(args, model, optimizer, tokenizer, global_step):
    """Save the joint model, its generator (G/) and discriminator (D/)
    sub-models, optimizer state, configs and vocab to
    ``args.output_dir/lm-checkpoint-<global_step>``."""
    output_dir = args.output_dir / f'lm-checkpoint-{global_step}'
    # Create G/ and D/ up front: config files are written into them directly.
    for directory in (output_dir, output_dir / "G", output_dir / "D"):
        directory.mkdir(parents=True, exist_ok=True)
    # Unwrap DataParallel / DistributedDataParallel. A bare model has no
    # `.module`, so never reach through `model.module` directly.
    model_to_save = model.module if hasattr(model, 'module') else model
    model_to_save.save_pretrained(str(output_dir))
    torch.save(args, str(output_dir / 'training_args.bin'))
    logger.info("Saving model checkpoint to %s", output_dir)
    model_to_save.generator.save_pretrained(str(output_dir / "G"))
    logger.info("Saving generator model checkpoint to %s", output_dir / "G")
    model_to_save.electra.save_pretrained(str(output_dir / "D"))
    logger.info("Saving electra model checkpoint to %s", output_dir / "D")
    torch.save(optimizer.state_dict(), str(output_dir / "optimizer.bin"))
    # save config
    config_targets = (
        (output_dir / CONFIG_NAME, model_to_save.config),
        (output_dir / "D" / CONFIG_NAME, model_to_save.electra.config),
        (output_dir / "G" / CONFIG_NAME, model_to_save.generator.config),
    )
    for config_file, config in config_targets:
        with open(str(config_file), 'w', encoding='utf-8') as f:
            f.write(config.to_json_string())
    # save vocab
    tokenizer.save_vocabulary(output_dir)


def main():
    """Pre-train an ELECTRA generator/discriminator pair on pregenerated shards.

    Reads ``{data_name}_file_{i}.json`` shards and their
    ``{data_name}_file_{i}_metrics.json`` files from
    ``<data_dir>/corpus/train``. It trains ``ElectraForPreTraining`` with AdamW
    and a linear warmup/decay schedule. Every ``--num_eval_steps`` optimizer
    updates it logs running loss/accuracy meters. Every ``--num_save_steps``
    updates it writes a checkpoint; only rank 0 or non-distributed runs do this.
    """
    parser = ArgumentParser()
    ## Required parameters
    parser.add_argument(
        "--data_dir",
        default="dataset",
        type=str,
        help=
        "The input data dir. Should contain the .tsv files (or other data files) for the task."
    )
    parser.add_argument("--config_path",
                        default="prev_trained_model/electra_small/config.json",
                        type=str)
    parser.add_argument("--vocab_path",
                        default="prev_trained_model/electra_small/vocab.txt",
                        type=str)
    parser.add_argument(
        "--output_dir",
        default="outputs",
        type=str,
        help=
        "The output directory where the model predictions and checkpoints will be written."
    )
    parser.add_argument("--model_path",
                        default='prev_trained_model/electra_small',
                        type=str)
    parser.add_argument('--data_name', default='electra', type=str)
    parser.add_argument(
        "--file_num",
        type=int,
        default=10,
        help="Number of dynamic masking to pregenerate (with different masks)")
    parser.add_argument(
        "--reduce_memory",
        action="store_true",
        help=
        "Store training data as on-disc memmaps to massively reduce memory usage"
    )
    parser.add_argument("--epochs",
                        type=int,
                        default=4,
                        help="Number of epochs to train for")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    # type=int: without it a CLI-supplied value is a str and `%` would do
    # string formatting instead of modulo.
    parser.add_argument('--num_eval_steps', default=100, type=int)
    parser.add_argument('--num_save_steps', default=2000, type=int)
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument("--weight_decay",
                        default=0.01,
                        type=float,
                        help="Weight decay if we apply some.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument("--train_batch_size",
                        default=128,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--gen_weight",
                        default=1.0,
                        type=float,
                        help='masked language modeling / generator loss')
    parser.add_argument("--disc_weight",
                        default=50,
                        type=float,
                        help='discriminator loss')
    parser.add_argument('--untied_generator',
                        action='store_true',
                        help='tie all generator/discriminator weights?')
    parser.add_argument('--temperature',
                        default=0,
                        type=float,
                        help='temperature for sampling from generator')
    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n")
    parser.add_argument("--warmup_proportion",
                        default=0.1,
                        type=float,
                        help="Linear warmup over warmup_steps.")
    parser.add_argument("--adam_epsilon",
                        default=1e-8,
                        type=float,
                        help="Epsilon for Adam optimizer.")
    parser.add_argument('--max_grad_norm', default=1.0, type=float)
    parser.add_argument("--learning_rate",
                        default=0.000176,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--fp16_opt_level',
        type=str,
        default='O2',
        help=
        "For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
        "See details at https://nvidia.github.io/apex/amp.html")
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--continue_train',
                        default='',
                        help="continue train path")
    args = parser.parse_args()

    args.data_dir = Path(args.data_dir)
    args.output_dir = Path(args.output_dir)
    # The log file lives inside output_dir, so it must exist first.
    args.output_dir.mkdir(parents=True, exist_ok=True)
    pregenerated_data = args.data_dir / "corpus/train"
    init_logger(log_file=str(args.output_dir / "train_albert_model.log"))
    if not pregenerated_data.is_dir():
        raise ValueError(
            "--pregenerated_data should point to the folder of files made by prepare_lm_data_mask.py!"
        )

    # Count shard files and training examples. Stop at the first missing
    # shard; only the shards found before it are used for training.
    samples_per_epoch = 0
    num_data_files = 0
    for i in range(args.file_num):
        data_file = pregenerated_data / f"{args.data_name}_file_{i}.json"
        metrics_file = pregenerated_data / f"{args.data_name}_file_{i}_metrics.json"
        if data_file.is_file() and metrics_file.is_file():
            metrics = json.loads(metrics_file.read_text())
            samples_per_epoch += metrics['num_training_examples']
            num_data_files += 1
        else:
            if i == 0:
                raise SystemExit("No training data was found!")
            print(
                f"Warning! There are fewer epochs of pregenerated data ({i}) than training epochs ({args.epochs})."
            )
            print(
                "This script will loop over the available data, but training diversity may be negatively impacted."
            )
            break
    logger.info(f"samples_per_epoch: {samples_per_epoch}")

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        # With --no_cuda the model stays on CPU, so never wrap it in
        # DataParallel.
        args.n_gpu = 0 if args.no_cuda else torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        args.n_gpu = 1
        # Initializes the distributed backend which will take care of
        # sychronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info(
        f"device: {device} , distributed training: {bool(args.local_rank != -1)}, 16-bits training: {args.fp16}"
    )

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            f"Invalid gradient_accumulation_steps parameter: {args.gradient_accumulation_steps}, should be >= 1"
        )
    # --train_batch_size is the effective batch; each forward pass sees a slice.
    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    seed_everything(args.seed)
    tokenizer = BertTokenizer.from_pretrained(args.vocab_path,
                                              do_lower_case=args.do_lower_case)
    total_train_examples = samples_per_epoch * args.epochs
    num_train_optimization_steps = int(total_train_examples /
                                       args.train_batch_size /
                                       args.gradient_accumulation_steps)
    if args.local_rank != -1:
        num_train_optimization_steps = (num_train_optimization_steps //
                                        torch.distributed.get_world_size())
    args.warmup_steps = int(num_train_optimization_steps *
                            args.warmup_proportion)

    bert_config = ElectraConfig.from_pretrained(args.config_path,
                                                gen_weight=args.gen_weight,
                                                temperature=args.temperature,
                                                disc_weight=args.disc_weight)
    model = ElectraForPreTraining(config=bert_config)
    if args.continue_train:
        print(f"Continue train from {args.continue_train}")
        model = model.from_pretrained(args.continue_train)
    elif args.model_path:
        print("载入预训练模型")
        model.generator = AutoModel.from_pretrained(args.model_path + "/G")
        model.electra = AutoModel.from_pretrained(args.model_path + "/D")
    model.to(device)

    # Prepare optimizer: no weight decay on biases and LayerNorm parameters.
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params': [
            p for n, p in param_optimizer
            if not any(nd in n for nd in no_decay)
        ],
        'weight_decay': args.weight_decay
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay': 0.0
    }]
    optimizer = AdamW(params=optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=num_train_optimization_steps)
    if args.fp16:
        try:
            from apex import amp
        except ImportError as err:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            ) from err
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank)

    global_step = 0
    g_metric = LMAccuracy()
    d_metric = AccuracyThresh()
    tr_g_acc = AverageMeter()
    tr_d_acc = AverageMeter()
    tr_loss = AverageMeter()
    tr_g_loss = AverageMeter()
    tr_d_loss = AverageMeter()
    meters = (tr_g_acc, tr_d_acc, tr_loss, tr_g_loss, tr_d_loss)
    train_logs = {}
    logger.info("***** Running training *****")
    logger.info(f" Num examples = {total_train_examples}")
    logger.info(f" Batch size = {args.train_batch_size}")
    logger.info(f" Num steps = {num_train_optimization_steps}")
    logger.info(f" warmup_steps = {args.warmup_steps}")
    logger.info(f" Num workable gpus = {args.n_gpu}")
    start_time = time.time()
    seed_everything(args.seed)  # Added here for reproducibility
    for epoch in range(args.epochs):
        # Iterate only over shard files that were actually found above.
        for idx in range(num_data_files):
            epoch_dataset = PregeneratedDataset(
                file_id=idx,
                training_path=pregenerated_data,
                tokenizer=tokenizer,
                reduce_memory=args.reduce_memory,
                data_name=args.data_name)
            if args.local_rank == -1:
                train_sampler = RandomSampler(epoch_dataset)
            else:
                train_sampler = DistributedSampler(epoch_dataset)
                # DistributedSampler only reshuffles when told the epoch
                # changed.
                train_sampler.set_epoch(epoch)
            train_dataloader = DataLoader(epoch_dataset,
                                          sampler=train_sampler,
                                          batch_size=args.train_batch_size)
            model.train()
            for step, batch in enumerate(train_dataloader):
                batch = tuple(t.to(device) for t in batch)
                input_ids, input_mask, segment_ids, lm_label_ids = batch
                outputs = model(input_ids=input_ids,
                                token_type_ids=segment_ids,
                                attention_mask=input_mask,
                                masked_lm_labels=lm_label_ids)
                loss, g_loss, d_loss, d_logits, g_logits, is_replaced_label = outputs
                # Discriminator accuracy only over positions where
                # input_mask == 1.
                active_indices = input_mask.view(-1) == 1
                active_logits = d_logits.view(-1)[active_indices]
                active_labels = is_replaced_label.view(-1)[active_indices]
                g_metric(logits=g_logits.view(-1, bert_config.vocab_size),
                         target=lm_label_ids.view(-1))
                d_metric(logits=active_logits.view(-1, 1),
                         target=active_labels)
                if args.n_gpu > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu.
                    g_loss = g_loss.mean()
                    d_loss = d_loss.mean()
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps
                if args.fp16:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()
                batch_size = input_ids.size(0)
                tr_g_acc.update(g_metric.value(), n=batch_size)
                tr_d_acc.update(d_metric.value(), n=batch_size)
                tr_loss.update(loss.item(), n=1)
                tr_g_loss.update(g_loss.item(), n=1)
                tr_d_loss.update(d_loss.item(), n=1)
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    if args.fp16:
                        torch.nn.utils.clip_grad_norm_(
                            amp.master_params(optimizer), args.max_grad_norm)
                    else:
                        torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                       args.max_grad_norm)
                    # optimizer.step() must precede scheduler.step(),
                    # otherwise the first LR value of the schedule is skipped.
                    optimizer.step()
                    scheduler.step()
                    optimizer.zero_grad()
                    global_step += 1
                    if (args.num_eval_steps > 0
                            and global_step % args.num_eval_steps == 0):
                        now = time.time()
                        # Time since the previous log line (labelled "ETA"
                        # in the output).
                        eta_format = _format_elapsed(now - start_time)
                        train_logs['loss'] = tr_loss.avg
                        train_logs['g_acc'] = tr_g_acc.avg
                        train_logs['d_acc'] = tr_d_acc.avg
                        train_logs['g_loss'] = tr_g_loss.avg
                        train_logs['d_loss'] = tr_d_loss.avg
                        show_info = f'[Training]:[{epoch}/{args.epochs}]{global_step}/{num_train_optimization_steps} ' \
                                    f'- ETA: {eta_format}' + "-".join(
                            [f' {key}: {value:.4f} ' for key, value in train_logs.items()])
                        logger.info(show_info)
                        for meter in meters:
                            meter.reset()
                        start_time = now
                    if (args.local_rank in [-1, 0] and args.num_save_steps > 0
                            and global_step % args.num_save_steps == 0):
                        _save_checkpoint(args, model, optimizer, tokenizer,
                                         global_step)
def take_train_steps(args, model, tokenizer, train_dataloader, prune):
    """Run a short fine-tuning loop over ``train_dataloader`` while driving ``prune`` hooks.

    A fresh AdamW optimizer and linear warmup/decay schedule are built on every
    call. Then, for each of ``int(args.num_train_epochs)`` epochs:

    * ``prune.on_epoch_begin(epoch)`` is called first.
    * ``prune.on_batch_begin(step)`` is called before each batch.
    * ``prune.on_batch_end()`` is called after each optimizer update.
    * The epoch stops once an update happens at ``step >= 20``.
    * ``prune.on_epoch_end()`` is called last.

    ``token_type_ids`` (``batch[2]``) is not passed to the model.

    Side effects: sets ``args.warmup_steps``. It also sets
    ``args.num_train_epochs`` when ``args.max_steps > 0``.
    """
    if args.max_steps > 0:
        num_training_steps = args.max_steps
        # max(1, ...) avoids a ZeroDivisionError when the dataloader yields
        # fewer batches than gradient_accumulation_steps.
        updates_per_epoch = max(
            1, len(train_dataloader) // args.gradient_accumulation_steps)
        args.num_train_epochs = args.max_steps // updates_per_epoch + 1
    else:
        num_training_steps = (len(train_dataloader) //
                              args.gradient_accumulation_steps *
                              args.num_train_epochs)
    args.warmup_steps = int(num_training_steps * args.warmup_proportion)

    # Prepare optimizer and schedule (linear warmup and decay). Biases and
    # LayerNorm weights are excluded from weight decay.
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in model.named_parameters()
                    if not any(nd in n for nd in no_decay)],
         'weight_decay': args.weight_decay},
        {'params': [p for n, p in model.named_parameters()
                    if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0},
    ]
    optimizer = AdamW(params=optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=num_training_steps)

    logger.info("***** Running training *****")
    # `train_dataset` is not a parameter of this function; read the dataset
    # from the dataloader instead.
    logger.info(" Num examples = %d", len(train_dataloader.dataset))
    logger.info(" Num Epochs = %d", args.num_train_epochs)
    logger.info(" Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        " Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
    logger.info(" Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info(" Total optimization steps = %d", num_training_steps)

    for epoch in range(int(args.num_train_epochs)):
        pbar = ProgressBar(n_total=len(train_dataloader), desc='Training')
        model.train()
        prune.on_epoch_begin(epoch)
        for step, batch in enumerate(train_dataloader):
            prune.on_batch_begin(step)
            batch = tuple(t.to(args.device) for t in batch)
            inputs = {
                'input_ids': batch[0],
                'attention_mask': batch[1],
                'labels': batch[3]
            }
            # inputs['token_type_ids'] = batch[2]
            outputs = model(**inputs)
            loss = outputs[0]  # model outputs are always tuple in transformers (see doc)
            if args.n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps
            loss.backward()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                # Clip once per update, on the fully accumulated gradient.
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               args.max_grad_norm)
                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                prune.on_batch_end()
                if step >= 20:
                    break
                # NOTE(review): because of the break above, `step` is < 20
                # here, so `step % args.logging_steps == 20` can never hold
                # and evaluate() is unreachable. Confirm the intended
                # condition.
                if (args.local_rank in [-1, 0] and args.logging_steps > 0
                        and step % args.logging_steps == 20):
                    # Log metrics
                    if args.local_rank == -1:
                        # Only evaluate when single GPU otherwise metrics may
                        # not average well
                        evaluate(args, model, tokenizer)
            pbar(step, {'loss': loss.item()})
        prune.on_epoch_end()
        print(" ")
        if 'cuda' in str(args.device):
            torch.cuda.empty_cache()