def create_model(embeddings, d_model, d_ff, num_heads, num_layers, rpr_k, rpr_value_on, d_k, checkpoint_name, activation):
    rpr_k = listify(rpr_k)
    if len(rpr_k) == 0 or rpr_k[0] < 1:
        rpr_k = None
    elif len(rpr_k) == 1:
        rpr_k = rpr_k[0]
    logger.info("Creating tied encoder decoder model")
    model = TransformerLanguageModel.create({'x': embeddings},
                                            hsz=d_model,
                                            d_ff=d_ff,
                                            tie_weights=True,
                                            dropout=0,
                                            gpu=False,
                                            num_heads=num_heads,
                                            layers=num_layers,
                                            rpr_k=rpr_k,
                                            rpr_value_on=rpr_value_on,
                                            d_k=d_k,
                                            activation=activation,
                                            src_keys=['x'], tgt_key='x')
    if checkpoint_name.endswith('npz'):
        load_tlm_npz(model, checkpoint_name)
    else:
        tlm_load_state_dict(model, checkpoint_name)
    model.eval()
    print(model)
    return model
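# A minimal usage sketch for create_model (not part of the original script; the vocab,
# checkpoint path, and hyperparameter values below are hypothetical assumptions):
#
#   preproc_data = baseline.embeddings.load_embeddings('x', dsz=512, known_vocab=vocab,
#                                                      embed_type='learned-positional')
#   model = create_model(preproc_data['embeddings'], d_model=512, d_ff=2048, num_heads=8,
#                        num_layers=8, rpr_k=[8], rpr_value_on=True, d_k=None,
#                        checkpoint_name='checkpoint-step-10000.npz', activation='gelu')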
def _round_trip(embed_type, rpr_k=None):
    test_file = os.path.join(file_loc, "test_data", "blah.npz")
    d_model = 40
    vocab_x = {'a': 1, 'aardvark': 100, 'beandip': 42, 'cheerio': 86, 'dumdum': 129, 'eoyre': 3}
    embeddings = {}
    vocabs = {'x': vocab_x}
    src_x_embedding = baseline.embeddings.load_embeddings('x',
                                                          dsz=d_model,
                                                          known_vocab=vocab_x,
                                                          embed_type=embed_type)
    embeddings['x'] = src_x_embedding['embeddings']
    src_model = TransformerLanguageModel.create(embeddings,
                                                hsz=d_model,
                                                dropout=0.1,
                                                gpu=False,
                                                num_heads=4,
                                                layers=2,
                                                rpr_k=rpr_k,
                                                src_keys=['x'], tgt_key='x')
    save_tlm_npz(src_model, test_file)
    dst_x_embedding = baseline.embeddings.load_embeddings('x',
                                                          dsz=d_model,
                                                          known_vocab=vocab_x,
                                                          embed_type=embed_type)
    embeddings['x'] = dst_x_embedding['embeddings']
    dst_model = TransformerLanguageModel.create(embeddings,
                                                hsz=d_model,
                                                dropout=0.1,
                                                gpu=False,
                                                num_heads=4,
                                                layers=2,
                                                rpr_k=rpr_k,
                                                src_keys=['x'], tgt_key='x')
    load_tlm_npz(dst_model, test_file)
    B = 4
    T = 7
    a_batch = torch.randint(0, 9, (B, T)).long()
    a_lengths = torch.randint(0, T, (B,)).long()
    out_pyt1 = _call_model(src_model, {'x': a_batch, 'lengths': a_lengths}).detach().numpy()
    out_pyt2 = _call_model(dst_model, {'x': a_batch, 'lengths': a_lengths}).detach().numpy()
    return np.allclose(out_pyt1, out_pyt2, atol=1e-6)
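# Hedged sketch of how _round_trip might be driven from pytest; these test names and
# parameter combinations are illustrative, not from the original file. Each call saves a
# freshly initialized model to npz, restores it into a second model, and checks that the
# two produce numerically identical outputs.
def test_round_trip_default_embeddings():
    # Save/restore with the default (absolute) embeddings
    assert _round_trip('default')


def test_round_trip_relative_attention():
    # Save/restore with relative position embeddings of window size 8
    assert _round_trip('learned-positional', rpr_k=8)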
@classmethod
def load(cls, embeddings, **kwargs):
    c = cls("tlm-words-embed", **kwargs)
    if embeddings.endswith('.bin'):
        # HuggingFace checkpoint, convert on the fly
        from eight_mile.pytorch.serialize import load_tlm_transformers_bin, BERT_HF_FT_LAYER_MAP
        unmatch = load_tlm_transformers_bin(c, embeddings, replace_layers=BERT_HF_FT_LAYER_MAP)
        if unmatch['missing'] or unmatch['unexpected']:
            raise Exception("Unable to load the HuggingFace checkpoint")
        # Fully restored from the converted checkpoint, so skip the loaders below,
        # which do not understand HuggingFace parameter naming
        return c
    if mime_type(embeddings) == 'application/zip':
        load_tlm_npz(c, embeddings)
    else:
        tlm_load_state_dict(c, embeddings)
    return c
@classmethod
def create(cls, embeddings, **kwargs):
    lm = cls()
    lm.gpu = kwargs.get('gpu', True)
    lm.tgt_key = kwargs.get('tgt_key')
    if lm.tgt_key is None:
        raise Exception('Need a `tgt_key` to know which source vocabulary should be used for destination')
    lm.src_keys = kwargs.get('src_keys', embeddings.keys())
    lm.create_layers(embeddings, **kwargs)
    checkpoint_name = kwargs.get('checkpoint')
    if checkpoint_name is not None:
        if checkpoint_name.endswith('npz'):
            load_tlm_npz(lm, checkpoint_name)
        else:
            lm.load_state_dict(torch.load(checkpoint_name))
    return lm
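# Hedged usage sketch for `create` (the embedding object and checkpoint name are
# assumptions for illustration; `checkpoint`, `src_keys`, and `tgt_key` are the kwargs
# the factory above actually reads, and the rest flow through to create_layers):
#
#   lm = TransformerLanguageModel.create({'x': x_embedding},
#                                        hsz=512, num_heads=8, layers=8,
#                                        src_keys=['x'], tgt_key='x',
#                                        checkpoint='checkpoint-step-10000.npz')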
@classmethod
def load(cls, embeddings, **kwargs):
    c = cls("tlm-words-embed", **kwargs)
    if embeddings.endswith('.bin'):
        # HuggingFace checkpoint, convert on the fly
        from eight_mile.pytorch.serialize import load_tlm_transformers_bin, BERT_HF_FT_LAYER_MAP
        unmatch = load_tlm_transformers_bin(c, embeddings, replace_layers=BERT_HF_FT_LAYER_MAP)
        if unmatch['missing'] or unmatch['unexpected']:
            raise Exception("Unable to load the HuggingFace checkpoint")
        # Fully restored from the converted checkpoint, so skip the loaders below,
        # which do not understand HuggingFace parameter naming
        return c
    if mime_type(embeddings) == 'application/zip' and not embeddings.endswith("pth"):
        keys_to_restore = set(list(c.embeddings.keys()))
        filtered_keys = keys_to_restore.difference(c.skippable)
        if not keys_to_restore:
            raise Exception("No keys to restore!")
        if len(filtered_keys) < len(keys_to_restore):
            logger.warning("Restoring only keys [%s]", ' '.join(filtered_keys))
        load_tlm_npz(c, embeddings, filtered_keys)
    else:
        tlm_load_state_dict(c, embeddings)
    return c
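# The loader above dispatches on checkpoint format. A hedged summary, where
# `TLMEmbeddings` stands in for whatever class owns this method and the file
# names are hypothetical:
#
#   TLMEmbeddings.load('bert-base-uncased.bin')  # HuggingFace checkpoint, converted on the fly
#   TLMEmbeddings.load('tlm-step-10000.npz')     # npz archive (detected via its zip mime type)
#   TLMEmbeddings.load('tlm-step-10000.pth')     # raw PyTorch state dict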
def train():
    parser = ArgumentParser()
    parser.add_argument("--basedir", type=str)
    parser.add_argument("--train_file", type=str, required=True, help='File path to use for train file')
    parser.add_argument("--valid_file", type=str, required=True, help='File path to use for valid file')
    parser.add_argument("--dataset_key", default="tlm", help="dataset key for basedir")
    parser.add_argument("--embed_type", type=str, default='default',
                        choices=["default", "positional", "learned-positional"],
                        help="register label of the embeddings")
    parser.add_argument("--d_model", type=int, default=512, help="Model dimension (and embedding dsz)")
    parser.add_argument("--d_ff", type=int, default=2048, help="FFN dimension")
    parser.add_argument("--d_k", type=int, default=None, help="Dimension per head. Use if num_heads=1 to reduce dims")
    parser.add_argument("--num_heads", type=int, default=8, help="Number of heads")
    parser.add_argument("--num_layers", type=int, default=8, help="Number of layers")
    parser.add_argument("--num_train_workers", type=int, default=4, help="Number train workers")
    parser.add_argument("--nctx", type=int, default=256, help="Max input length")
    parser.add_argument("--file_type", default='json', help="Glob pattern for data")
    parser.add_argument("--batch_size", type=int, default=256, help="Batch Size")
    parser.add_argument("--subword_model_file", type=str, help="The BPE model file", required=True)
    parser.add_argument("--subword_vocab_file", type=str, help="The BPE subword vocab", required=True)
    parser.add_argument("--dropout", type=float, default=0.1, help="Dropout")
    parser.add_argument("--ffn_pdrop", type=float, default=0.0, help="Dropout in the dense stack")
    parser.add_argument("--layer_drop", type=float, default=0.0, help="LayerDrop to apply")
    parser.add_argument("--lr_scheduler", type=str, default='cosine', help="The type of learning rate decay scheduler")
    parser.add_argument("--lr_decay_steps", type=int, help="decay steps of lr scheduler")
    parser.add_argument("--lr_decay_rate", type=float, help="decay rate of lr scheduler")
    parser.add_argument("--lr_alpha", type=float, help="parameter alpha for cosine decay scheduler")
    parser.add_argument("--optim", default="adamw", type=str, help="Optimizer to use (defaults to adamw)")
    parser.add_argument("--lr", type=float, default=4.0e-4, help="Learning rate")
    parser.add_argument("--clip", type=float, default=1.0, help="Clipping gradient norm")
    parser.add_argument("--weight_decay", type=float, default=1.0e-2, help="Weight decay")
    parser.add_argument("--epochs", type=int, default=32, help="Num training epochs")
    parser.add_argument("--restart_from", type=str, help="Option allows you to restart from a previous checkpoint")
    parser.add_argument("--restart_tt", type=str, help="Optional param for legacy checkpoints",
                        choices=['step', 'epoch', 'ignore'])
    parser.add_argument("--warmup_steps", type=int, default=10000, help="Num warmup steps")
    parser.add_argument("--saves_per_epoch", type=int, default=10, help="The number of checkpoints to save per epoch")
    parser.add_argument("--mlm", type=str2bool, default=True, help="Use Masked Language Model (MLM) objective")
    parser.add_argument("--preprocessed", type=str2bool, default=True, help="Has the data already been preprocessed?")
    parser.add_argument('--rpr_k', type=int, default=[8], nargs='+',
                        help="Relative attention positional sizes; pass 0 if you don't want relative attention")
    parser.add_argument('--rpr_value_on', type=str2bool, default=True,
                        help="In relative attention, whether to add the positional correction to the values "
                             "in addition to the correction to the attention matrix")
    parser.add_argument("--windowed_ra", type=str2bool, default=False, help="Whether to prevent attention beyond rpr_k")
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu",
                        help="Device (cuda or cpu)")
    parser.add_argument("--distributed", type=str2bool, default=False, help="Are we doing distributed training?")
    parser.add_argument("--local_rank", type=int, default=-1,
                        help="Local rank for distributed training (-1 means use the environment variables to find)")
    args = parser.parse_args()

    if args.basedir is None:
        args.basedir = 'lm-{}-bpe-{}'.format(args.dataset_key, os.getpid())
    logging.basicConfig(level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    num_gpus = get_num_gpus_multiworker()
    args.distributed = args.distributed or num_gpus > 1
    logger.info(f"Using {num_gpus} GPUs in this job.")

    do_on_demand_masking = args.mlm and not args.preprocessed
    if do_on_demand_masking:
        logger.info("On-demand masking is turned on")
    if args.distributed:
        args.device, updated_local_rank = init_distributed(args.local_rank)
        args.local_rank = updated_local_rank

    if args.file_type == 'tfrecord':
        reader_type = 'tfrecord'
    elif args.preprocessed:
        reader_type = 'preprocessed'
    else:
        reader_type = 'lang'
    reader = MultiFileDatasetReader(src_nctx=args.nctx,
                                    model_file=args.subword_model_file,
                                    vocab_file=args.subword_vocab_file,
                                    file_type=args.file_type,
                                    reader_type=reader_type,
                                    record_keys=['x', 'y'] if args.mlm else ['x'])

    # This looks a bit funny but the streaming reader ignores our vocab and gives us the one from the subword_model
    # However, we do need to get counts from our dataset for validation so we can calculate the perplexity
    vocab = reader.build_vocab([args.valid_file])
    # If we are not using chars, then use 'x' for both input and output
    preproc_data = baseline.embeddings.load_embeddings('x',
                                                       dsz=args.d_model,
                                                       known_vocab=vocab['x'],
                                                       preserve_vocab_indices=True,
                                                       embed_type=args.embed_type)
    vocabs = preproc_data['vocab']
    os.makedirs(args.basedir, exist_ok=True)
    # We want to make sure to save our input vocab into the basedir for reuse later
    write_json(vocabs, os.path.join(args.basedir, 'vocabs.json'))
    embeddings = {'x': preproc_data['embeddings']}
    logger.info("Loaded embeddings")

    train_set = reader.load(args.train_file, vocabs)
    valid_set = reader.load(args.valid_file, vocabs, distribute=False, shuffle=False)
    train_loader = DataLoader(train_set, batch_size=args.batch_size, num_workers=args.num_train_workers)
    valid_loader = DataLoader(valid_set, batch_size=args.batch_size)
    logger.info("Loaded datasets")
    logger.info("Using embedding type [%s]", args.embed_type)

    if args.mlm:
        mask_from = vocabs
        vocab_size = len(mask_from)
        # Default to -1 so a missing [MASK] token actually trips the check below
        mask_value = mask_from.get("[MASK]", -1)
        if mask_value == -1:
            logger.error("We could not find a suitable masking token in the vocab")
            return

    if len(args.rpr_k) == 0 or args.rpr_k[0] < 1:
        rpr_k = None
    elif len(args.rpr_k) == 1:
        rpr_k = args.rpr_k[0]
    else:
        rpr_k = args.rpr_k

    TLM = TransformerMaskedLanguageModel if args.mlm else TransformerLanguageModel
    model = TLM.create(embeddings,
                       hsz=args.d_model,
                       d_ff=args.d_ff,
                       tie_weights=True,
                       dropout=args.dropout,
                       gpu=False,
                       num_heads=args.num_heads,
                       layers=args.num_layers,
                       rpr_k=rpr_k,
                       d_k=args.d_k,
                       ffn_pdrop=args.ffn_pdrop,
                       windowed_ra=args.windowed_ra,
                       rpr_value_on=args.rpr_value_on,
                       layer_drop=args.layer_drop,
                       src_keys=['x'], tgt_key='x')
    model.to(args.device)
    loss_function = model.create_loss()
    loss_function.to(args.device)
    logger.info("Loaded model and loss")
and loss") steps_per_epoch = len(train_loader) // num_gpus valid_steps = len(valid_loader) update_on = steps_per_epoch // args.saves_per_epoch report_on = max(10, update_on) // 10 logger.info( f"Steps per epoch per GPU: {steps_per_epoch}. Saving checkpoint every {update_on} steps." ) lr_decay = get_lr_decay(args.lr_scheduler, args.lr, steps_per_epoch, args.epochs, logger, decay_steps=args.lr_decay_steps, decay_rate=args.lr_decay_rate, alpha=args.lr_alpha) linear_warmup = WarmupLinearSchedulerPyTorch(args.warmup_steps, lr=args.lr) lr_sched = CompositeLRScheduler(linear_warmup, lr_decay, lr=args.lr) global_step = 0 start_epoch = 0 if args.restart_from: if args.restart_from.endswith('npz'): load_tlm_npz(model, args.restart_from) else: model.load_state_dict(torch.load(args.restart_from)) vec = args.restart_from.split("-") if args.restart_tt: tick_type = args.restart_tt else: tick_type = vec[-2] step_num = int(vec[-1].split(".")[0]) if tick_type == 'epoch': start_epoch = step_num global_step = start_epoch * steps_per_epoch elif tick_type == 'step': start_epoch = step_num // steps_per_epoch global_step = step_num else: logger.warning( f"The previous tick was {step_num} but command-line specifies to ignore, setting to 0" ) logger.info( "Restarting from a previous checkpoint %s.\n\tStarting at global_step=%d, epoch=%d", args.restart_from, global_step, start_epoch + 1) optimizer = OptimizerManager(model, global_step, optim=args.optim, lr=args.lr, lr_function=lr_sched, weight_decay=args.weight_decay) logger.info("Model has {:,} parameters".format( sum(p.numel() for p in model.parameters() if p.requires_grad))) # Prepare model for distributed training if needed if args.distributed: # This program assume pure data parallelism, each model is on a single gpu # If we wanted to support model and data parallelism we would need to update # the selection of gpus based on rank, it would need to select multiple ids # based on rank, here we select only a single gpu and use it for input and # output. 
model = DistributedDataParallel(model, device_ids=[args.device], output_device=args.device) logger.info("Model located on %s", args.device) model_base = os.path.join(args.basedir, 'checkpoint') steps = global_step timer = Timer() for epoch in range(start_epoch, args.epochs): avg_loss = Average('average_train_loss') metrics = {} optimizer.zero_grad() timer.start() model.train() train_itr = iter(train_loader) for i in range(steps_per_epoch): batch = next(train_itr) steps += 1 x, y = batch inputs = x.to(args.device) labels = y.to(args.device) if do_on_demand_masking: inputs, labels, _ = on_demand_mlm_masking( inputs, labels, mask_value, vocab_size) inputs = {'x': inputs} labels = labels.transpose(0, 1).contiguous() logits = model(inputs, None)[0].transpose(0, 1).contiguous() if args.mlm: loss = loss_function(logits, labels) else: shift_logits = logits[:-1] shift_labels = labels[1:] loss = loss_function(shift_logits, shift_labels) loss.backward() avg_loss.update(loss.item()) torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip) optimizer.step() optimizer.zero_grad() if (i + 1) % report_on == 0: logging.info(avg_loss) if (i + 1) % update_on == 0 and args.local_rank < 1: elapsed = timer.elapsed(True) logging.info('elapsed time this epoch %d min', elapsed) logging.info('elapsed step time %f steps/min', i / elapsed) logging.info('LR: %f', optimizer.current_lr) save_checkpoint(model, model_base, steps, tick_type='step') # How much time elapsed in minutes elapsed = timer.elapsed(True) train_token_loss = avg_loss.avg # This is the average training token-level loss across all machines # This is the token-level training perplexity train_token_ppl = math.exp(train_token_loss) metrics['train_elapsed_min'] = elapsed metrics['average_train_loss'] = train_token_loss metrics['train_ppl'] = train_token_ppl if args.local_rank < 1: avg_valid_loss = Average('average_valid_loss') timer.start() model.eval() valid_itr = iter(valid_loader) for j in range(valid_steps): batch = next(valid_itr) with torch.no_grad(): x, y = batch inputs = x.to(args.device) labels = y.to(args.device) if do_on_demand_masking: inputs, labels, _ = on_demand_mlm_masking( inputs, labels, mask_value, vocab_size) inputs = {'x': inputs} labels = labels.transpose(0, 1).contiguous() logits = model(inputs, None)[0].transpose(0, 1).contiguous() if args.mlm: loss = loss_function(logits, labels) else: shift_logits = logits[:-1] shift_labels = labels[1:] loss = loss_function(shift_logits, shift_labels) avg_valid_loss.update(loss.item()) valid_token_loss = avg_valid_loss.avg valid_token_ppl = math.exp(valid_token_loss) metrics['valid_elapsed_min'] = timer.elapsed(True) metrics['average_valid_loss'] = valid_token_loss metrics['average_valid_word_ppl'] = valid_token_ppl logger.info(metrics) save_checkpoint(model, model_base, epoch, save_npz=True)
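# An illustrative invocation of this trainer (the script name and data paths are
# placeholders; the flags are the ones registered in the parser above):
#
#   python pretrain_tlm.py \
#       --train_file /data/train/*.json --valid_file /data/valid/*.json \
#       --subword_model_file bpe.model --subword_vocab_file bpe.vocab \
#       --d_model 512 --num_heads 8 --num_layers 8 --mlm true --rpr_k 8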
def run(basedir=None, train_file=None, valid_file=None, dataset_key='tlm', embed_type='default',
        d_model=512, d_ff=2048, d_k=None, num_heads=8, num_layers=8, num_train_workers=4,
        nctx=256, file_type='json', batch_size=256, subword_model_file=None, subword_vocab_file=None,
        dropout=0.1, ffn_pdrop=0.0, layer_drop=0.0, lr_scheduler='cosine', lr_decay_steps=None,
        lr_decay_rate=None, lr_alpha=0.0, optim='adamw', lr=4.0e-4, clip=1.0, weight_decay=1.0e-2,
        epochs=32, restart_from=None, restart_tt=None, warmup_steps=10000, saves_per_epoch=10,
        mlm=True, preprocessed=True, rpr_k=[8], rpr_value_on=False, windowed_ra=False,
        device="cuda", distributed=False, local_rank=-1, extra_tokens=["[CLS]", "[MASK]"],
        do_early_stopping=False, model_type='transformer-mlm', modules=[], ra_type=None,
        transformer_type=None, **kwargs):
    if basedir is None:
        basedir = 'lm-{}-bpe-{}'.format(dataset_key, os.getpid())
    logging.basicConfig(level=logging.INFO if local_rank in [-1, 0] else logging.WARN)
    for module in modules:
        import_user_module(module)
    num_gpus = get_num_gpus_multiworker()
    distributed = distributed or num_gpus > 1
    logger.info(f"Using {num_gpus} GPUs in this job.")

    do_on_demand_masking = mlm and not preprocessed
    if do_on_demand_masking:
        logger.info("On-demand masking is turned on")
    if distributed:
        device, updated_local_rank = init_distributed(local_rank)
        local_rank = updated_local_rank

    if file_type == 'tfrecord':
        reader_type = 'tfrecord'
    elif preprocessed:
        reader_type = 'preprocessed'
    else:
        reader_type = 'lang'
    reader = MultiFileDatasetReader(src_nctx=nctx,
                                    model_file=subword_model_file,
                                    vocab_file=subword_vocab_file,
                                    file_type=file_type,
                                    reader_type=reader_type,
                                    record_keys=['x', 'y'] if mlm else ['x'],
                                    extra_tokens=extra_tokens)

    # This looks a bit funny but the streaming reader ignores our vocab and gives us the one from the subword_model
    # However, we do need to get counts from our dataset for validation so we can calculate the perplexity
    vocab = reader.build_vocab([valid_file])
    # If we are not using chars, then use 'x' for both input and output
    preproc_data = baseline.embeddings.load_embeddings('x',
                                                       dsz=d_model,
                                                       known_vocab=vocab['x'],
                                                       preserve_vocab_indices=True,
                                                       embed_type=embed_type)
    vocabs = preproc_data['vocab']
    os.makedirs(basedir, exist_ok=True)
    # We want to make sure to save our input vocab into the basedir for reuse later
    write_json(vocabs, os.path.join(basedir, 'vocabs.json'))
    embeddings = {'x': preproc_data['embeddings']}
    logger.info("Loaded embeddings")

    train_set = reader.load(train_file, vocabs)
    valid_set = reader.load(valid_file, vocabs, distribute=False, shuffle=False)
    train_loader = DataLoader(train_set, batch_size=batch_size, num_workers=num_train_workers)
    valid_loader = DataLoader(valid_set, batch_size=batch_size)
    logger.info("Loaded datasets")
    logger.info("Using embedding type [%s]", embed_type)

    if 'mlm' in model_type:
        mask_from = vocabs
        vocab_size = len(mask_from)
        # Default to -1 so a missing [MASK] token actually trips the check below
        mask_value = mask_from.get("[MASK]", -1)
        if mask_value == -1:
            logger.error("We could not find a suitable masking token in the vocab")
            return

    if len(rpr_k) == 0 or rpr_k[0] < 1:
        rpr_k = None
    elif len(rpr_k) == 1:
        rpr_k = None if rpr_k[0] == 0 else rpr_k[0]

    if ra_type is not None and ra_type != 'shaw' and rpr_k is not None:
        logger.warning(f"Relative attention mismatch. You requested {ra_type} with rpr set. Setting it to 0")
        rpr_k = None

    model = create_lang_model(embeddings,
                              hsz=d_model,
                              nctx=nctx,  # Only for gMLP
                              d_ff=d_ff,
                              tie_weights=True,
                              dropout=dropout,
                              gpu=False,
                              num_heads=num_heads,
                              layers=num_layers,
                              rpr_k=rpr_k,
                              d_k=d_k,
                              ffn_pdrop=ffn_pdrop,
                              windowed_ra=windowed_ra,
                              rpr_value_on=rpr_value_on,
                              layer_drop=layer_drop,
                              model_type=model_type,
                              ra_type=ra_type,
                              transformer_type=transformer_type,
                              src_keys=['x'], tgt_key='x')
    model.to(device)
    loss_function = model.create_loss()
    loss_function.to(device)
    logger.info("Loaded model and loss")

    steps_per_epoch = len(train_loader) // num_gpus
    update_on = steps_per_epoch // saves_per_epoch
    report_on = max(10, update_on) // 10
    logger.info(f"Steps per epoch per GPU: {steps_per_epoch}. Saving checkpoint every {update_on} steps.")
    lr_decay = get_lr_decay(lr_scheduler, lr, steps_per_epoch, epochs, logger,
                            decay_steps=lr_decay_steps, decay_rate=lr_decay_rate, alpha=lr_alpha)
    linear_warmup = WarmupLinearSchedulerPyTorch(warmup_steps, lr=lr)
    lr_sched = CompositeLRScheduler(linear_warmup, lr_decay, lr=lr)

    global_step = 0
    start_epoch = 0
    if restart_from:
        if restart_from.endswith('npz'):
            load_tlm_npz(model, restart_from)
        else:
            model.load_state_dict(torch.load(restart_from))
        vec = restart_from.split("-")

        if restart_tt:
            tick_type = restart_tt
        else:
            tick_type = vec[-2]
        step_num = int(vec[-1].split(".")[0])
        if tick_type == 'epoch':
            start_epoch = step_num
            global_step = start_epoch * steps_per_epoch
        elif tick_type == 'step':
            start_epoch = step_num // steps_per_epoch
            global_step = step_num
        else:
            logger.warning(f"The previous tick was {step_num} but command-line specifies to ignore, setting to 0")

        logger.info("Restarting from a previous checkpoint %s.\n\tStarting at global_step=%d, epoch=%d",
                    restart_from, global_step, start_epoch + 1)

    optimizer = OptimizerManager(model, global_step, optim=optim, lr=lr,
                                 lr_function=lr_sched, weight_decay=weight_decay)
    logger.info("Model has {:,} parameters".format(sum(p.numel() for p in model.parameters() if p.requires_grad)))

    # Prepare model for distributed training if needed
    if distributed:
        # This program assumes pure data parallelism: each model replica lives on a single GPU.
        # If we wanted to support model and data parallelism we would need to update the selection
        # of GPUs based on rank, selecting multiple ids; here we select only a single GPU and use
        # it for both input and output.
        model = DistributedDataParallel(model, device_ids=[device], output_device=device, find_unused_parameters=True)
        logger.info("Model located on %s", device)

    model_base = os.path.join(basedir, 'checkpoint')
    steps = global_step
    best_valid_loss = np.inf

    timer = Timer()
    for epoch in range(start_epoch, epochs):
        avg_loss = Average('average_train_loss')
        metrics = {}
        optimizer.zero_grad()
        timer.start()
        model.train()
        train_itr = iter(train_loader)
        for i in range(steps_per_epoch):
            batch = next(train_itr)
            steps += 1
            x, y = batch
            inputs = x.to(device)
            labels = y.to(device)

            if do_on_demand_masking:
                inputs, labels, _ = on_demand_mlm_masking(inputs, labels, mask_value, vocab_size)
            inputs = {'x': inputs}
            labels = labels.contiguous()
            logits = model(inputs, None)[0].contiguous()
            if mlm:
                loss = loss_function(logits, labels)
            else:
                # Shift over the time dimension so logits at t predict tokens at t+1
                shift_logits = logits[:, :-1]
                shift_labels = labels[:, 1:]
                loss = loss_function(shift_logits, shift_labels)
            loss.backward()
            avg_loss.update(loss.item())

            torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
            optimizer.step()
            optimizer.zero_grad()
            if (i + 1) % report_on == 0:
                logging.info(avg_loss)
            if (i + 1) % update_on == 0 and local_rank < 1:
                elapsed = timer.elapsed(True)
                logging.info('elapsed time this epoch %d min', elapsed)
                logging.info('elapsed step time %f steps/min', i / elapsed)
                logging.info('LR: %f', optimizer.current_lr)

                if not do_early_stopping:
                    save_checkpoint(model, model_base, steps, tick_type='step')
                else:
                    valid_token_loss = validate(model, loss_function, valid_loader, avg_loss, timer, metrics,
                                                do_on_demand_masking, mlm, mask_value, vocab_size, device)
                    if valid_token_loss < best_valid_loss:
                        best_valid_loss = valid_token_loss
                        logger.info(f"New best valid loss: {best_valid_loss}. Saving checkpoint...")
                        save_checkpoint(model, model_base, steps, tick_type='step')
                    model.train()

        if not do_early_stopping:
            _ = validate(model, loss_function, valid_loader, avg_loss, timer, metrics,
                         do_on_demand_masking, mlm, mask_value, vocab_size, device)
            save_checkpoint(model, model_base, epoch, tick_type='epoch')
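# Hedged usage sketch for run() (the file paths are placeholders; every keyword below
# comes from the signature above):
#
#   run(train_file='/data/train/*.json',
#       valid_file='/data/valid/*.json',
#       subword_model_file='bpe.model',
#       subword_vocab_file='bpe.vocab',
#       model_type='transformer-mlm',
#       rpr_k=[8], epochs=32, do_early_stopping=True)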