def main():
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument("--input_dir", default=None, type=str, required=True,
                        help="The input data dir. Should contain .hdf5 files for the task.")
    parser.add_argument("--config_file", default="bert_config.json", type=str, required=False,
                        help="The BERT model config")
    ckpt_group = parser.add_mutually_exclusive_group(required=True)
    ckpt_group.add_argument("--ckpt_dir", default=None, type=str,
                            help="The ckpt directory, e.g. /results")
    ckpt_group.add_argument("--ckpt_path", default=None, type=str,
                            help="Path to the specific checkpoint")
    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument('--eval', dest='do_eval', action='store_true')
    group.add_argument('--prediction', dest='do_eval', action='store_false')

    ## Other parameters
    parser.add_argument("--bert_model", default="bert-large-uncased", type=str, required=False,
                        help="Bert pre-trained model selected in the list: bert-base-uncased, "
                        "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.")
    parser.add_argument("--max_seq_length", default=512, type=int,
                        help="The maximum total input sequence length after WordPiece tokenization. "
                        "Sequences longer than this will be truncated, and sequences shorter "
                        "than this will be padded.")
    parser.add_argument("--max_predictions_per_seq", default=80, type=int,
                        help="The maximum total number of masked tokens in the input sequence")
    parser.add_argument("--ckpt_step", default=-1, type=int, required=False,
                        help="The model checkpoint iteration, e.g. 1000")
    parser.add_argument("--eval_batch_size", default=8, type=int,
                        help="Total batch size for inference.")
    parser.add_argument("--max_steps", default=-1, type=int,
                        help="Total number of eval steps to perform, otherwise use full dataset")
    parser.add_argument("--no_cuda", default=False, action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank", type=int, default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed', type=int, default=42,
                        help="random seed for initialization")
    parser.add_argument('--fp16', default=False, action='store_true',
                        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument("--log_path", default="/workspace/dllogger_inference.out", type=str,
                        help="Output file for DLLogger")

    args = parser.parse_args()

    if 'LOCAL_RANK' in os.environ:
        args.local_rank = int(os.environ['LOCAL_RANK'])

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl', init_method='env://')

    if is_main_process():
        dllogger.init(backends=[
            dllogger.JSONStreamBackend(verbosity=dllogger.Verbosity.VERBOSE, filename=args.log_path),
            dllogger.StdOutBackend(verbosity=dllogger.Verbosity.VERBOSE, step_format=format_step)
        ])
    else:
        dllogger.init(backends=[])

    n_gpu = torch.cuda.device_count()
    if n_gpu > 1:
        assert args.local_rank != -1  # only use torch.distributed for multi-gpu

    dllogger.log(step="device: {} n_gpu: {}, distributed inference: {}, 16-bits inference: {}".format(
        device, n_gpu, bool(args.local_rank != -1), args.fp16), data={})

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    # Prepare model
    config = BertConfig.from_json_file(args.config_file)
    # Pad the vocabulary size for divisibility by 8
    if config.vocab_size % 8 != 0:
        config.vocab_size += 8 - (config.vocab_size % 8)
    model = BertForPreTraining(config)

    if args.ckpt_dir:
        if args.ckpt_step == -1:
            # retrieve the latest checkpoint
            model_names = [f for f in os.listdir(args.ckpt_dir) if f.endswith(".pt")]
            args.ckpt_step = max(int(x.split('.pt')[0].split('_')[1].strip()) for x in model_names)
            dllogger.log(step="load model saved at iteration", data={"number": args.ckpt_step})
        model_file = os.path.join(args.ckpt_dir, "ckpt_" + str(args.ckpt_step) + ".pt")
    else:
        model_file = args.ckpt_path
    state_dict = torch.load(model_file, map_location="cpu")["model"]
    model.load_state_dict(state_dict, strict=False)

    if args.fp16:
        model.half()  # all parameters and buffers are converted to half precision
    model.to(device)

    multi_gpu_training = args.local_rank != -1 and torch.distributed.is_initialized()
    if multi_gpu_training:
        model = DDP(model)

    files = [
        os.path.join(args.input_dir, f) for f in os.listdir(args.input_dir)
        if os.path.isfile(os.path.join(args.input_dir, f)) and 'test' in f
    ]
    files.sort()

    dllogger.log(step="***** Running Inference *****", data={})
    dllogger.log(step=" Inference batch", data={"size": args.eval_batch_size})

    model.eval()

    nb_instances = 0
    max_steps = args.max_steps if args.max_steps > 0 else np.inf
    global_step = 0
    total_samples = 0

    begin_infer = time.time()
    with torch.no_grad():
        if args.do_eval:
            final_loss = 0.0
            for data_file in files:
                dllogger.log(step="Opening ", data={"file": data_file})
                dataset = pretraining_dataset(input_file=data_file,
                                              max_pred_length=args.max_predictions_per_seq)
                if not multi_gpu_training:
                    sampler = RandomSampler(dataset)
                else:
                    sampler = DistributedSampler(dataset)
                datasetloader = DataLoader(dataset, sampler=sampler, batch_size=args.eval_batch_size,
                                           num_workers=4, pin_memory=True)
                for step, batch in enumerate(tqdm(datasetloader, desc="Iteration")):
                    if global_step > max_steps:
                        break
                    batch = [t.to(device) for t in batch]
                    input_ids, segment_ids, input_mask, masked_lm_labels, next_sentence_labels = batch
                    loss = model(input_ids=input_ids,
                                 token_type_ids=segment_ids,
                                 attention_mask=input_mask,
                                 masked_lm_labels=masked_lm_labels,
                                 next_sentence_label=next_sentence_labels)
                    final_loss += loss.item()
                    global_step += 1

                total_samples += len(datasetloader)
                torch.cuda.empty_cache()
                if global_step > max_steps:
                    break

            final_loss /= global_step
            if multi_gpu_training:
                final_loss = torch.tensor(final_loss, device=device)
                dist.all_reduce(final_loss)
                final_loss /= torch.distributed.get_world_size()
                final_loss = final_loss.item()
            if not multi_gpu_training or torch.distributed.get_rank() == 0:
                dllogger.log(step="Inference Loss", data={"final_loss": final_loss})
        else:  # inference
            # if multi_gpu_training:
            #     torch.distributed.barrier()
            # start_t0 = time.time()
            for data_file in files:
                dllogger.log(step="Opening ", data={"file": data_file})
                dataset = pretraining_dataset(input_file=data_file,
                                              max_pred_length=args.max_predictions_per_seq)
                if not multi_gpu_training:
                    sampler = RandomSampler(dataset)
                else:
                    sampler = DistributedSampler(dataset)
                datasetloader = DataLoader(dataset, sampler=sampler, batch_size=args.eval_batch_size,
                                           num_workers=4, pin_memory=True)
                for step, batch in enumerate(tqdm(datasetloader, desc="Iteration")):
                    if global_step > max_steps:
                        break
                    batch = [t.to(device) for t in batch]
                    input_ids, segment_ids, input_mask, masked_lm_labels, next_sentence_labels = batch
                    lm_logits, nsp_logits = model(input_ids=input_ids,
                                                  token_type_ids=segment_ids,
                                                  attention_mask=input_mask,
                                                  masked_lm_labels=None,
                                                  next_sentence_label=None)
                    nb_instances += input_ids.size(0)
                    global_step += 1

                total_samples += len(datasetloader)
                torch.cuda.empty_cache()
                if global_step > max_steps:
                    break
            # if multi_gpu_training:
            #     torch.distributed.barrier()
            if not multi_gpu_training or torch.distributed.get_rank() == 0:
                dllogger.log(step="Done Inferring on samples", data={})

    end_infer = time.time()
    dllogger.log(step="Inference perf",
                 data={"inference_sequences_per_second":
                       total_samples * args.eval_batch_size / (end_infer - begin_infer)})
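

# Illustrative sketch, not part of the original script: the --ckpt_dir branch above
# resolves the newest checkpoint by parsing the step number out of files named
# "ckpt_<step>.pt". Factored out on its own (same naming assumption, and a
# hypothetical helper name), the lookup amounts to:
def latest_checkpoint(ckpt_dir):
    """Return the path of the "ckpt_<step>.pt" file with the highest step number."""
    names = [f for f in os.listdir(ckpt_dir) if f.endswith(".pt")]
    steps = [int(n.split(".pt")[0].split("_")[1]) for n in names]
    return os.path.join(ckpt_dir, "ckpt_{}.pt".format(max(steps)))
# Example: latest_checkpoint("/results") might return "/results/ckpt_1000.pt".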
def main(): print("IN NEW MAIN XD\n") parser = argparse.ArgumentParser() ## Required parameters parser.add_argument( "--input_dir", default=None, type=str, required=True, help="The input data dir. Should contain .hdf5 files for the task.") parser.add_argument("--config_file", default="bert_config.json", type=str, required=False, help="The BERT model config") parser.add_argument("--ckpt_dir", default=None, type=str, required=True, help="The ckpt directory, e.g. /results") group = parser.add_mutually_exclusive_group(required=True) group.add_argument('--eval', dest='do_eval', action='store_true') group.add_argument('--prediction', dest='do_eval', action='store_false') ## Other parameters parser.add_argument( "--bert_model", default="bert-large-uncased", type=str, required=False, help="Bert pre-trained model selected in the list: bert-base-uncased, " "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese." ) parser.add_argument( "--max_seq_length", default=512, type=int, help= "The maximum total input sequence length after WordPiece tokenization. \n" "Sequences longer than this will be truncated, and sequences shorter \n" "than this will be padded.") parser.add_argument( "--max_predictions_per_seq", default=80, type=int, help="The maximum total of masked tokens in input sequence") parser.add_argument("--ckpt_step", default=-1, type=int, required=False, help="The model checkpoint iteration, e.g. 1000") parser.add_argument("--eval_batch_size", default=8, type=int, help="Total batch size for training.") parser.add_argument( "--max_steps", default=-1, type=int, help= "Total number of eval steps to perform, otherwise use full dataset") parser.add_argument("--no_cuda", default=False, action='store_true', help="Whether not to use CUDA when available") parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for distributed training on gpus") parser.add_argument('--seed', type=int, default=42, help="random seed for initialization") parser.add_argument( '--fp16', default=False, action='store_true', help="Whether to use 16-bit float precision instead of 32-bit") args = parser.parse_args() if args.local_rank == -1 or args.no_cuda: device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") else: torch.cuda.set_device(args.local_rank) device = torch.device("cuda", args.local_rank) # Initializes the distributed backend which will take care of sychronizing nodes/GPUs torch.distributed.init_process_group(backend='nccl', init_method='env://') n_gpu = torch.cuda.device_count() if n_gpu > 1: assert (args.local_rank != -1 ) # only use torch.distributed for multi-gpu logger.info("device %s n_gpu %d distributed inference %r", device, n_gpu, bool(args.local_rank != -1)) random.seed(args.seed) np.random.seed(args.seed) torch.manual_seed(args.seed) if n_gpu > 0: torch.cuda.manual_seed_all(args.seed) # Prepare model config = BertConfig.from_json_file(args.config_file) model = BertForPreTraining(config) if args.ckpt_step == -1: #retrieve latest model model_names = [ f for f in os.listdir(args.ckpt_dir) if f.endswith(".model") ] args.ckpt_step = max([ int(x.split('.model')[0].split('_')[1].strip()) for x in model_names ]) print("load model saved at iteraton", args.ckpt_step) model_file = os.path.join(args.ckpt_dir, "ckpt_" + str(args.ckpt_step) + ".model") state_dict = torch.load(model_file, map_location="cpu") model.load_state_dict(state_dict, strict=False) if args.fp16: model.half( ) # all parameters and buffers are converted to half precision 
model.to(device) multi_gpu_training = args.local_rank != -1 and torch.distributed.is_initialized( ) if multi_gpu_training: model = DDP(model) files = [ os.path.join(args.input_dir, f) for f in os.listdir(args.input_dir) if os.path.isfile(os.path.join(args.input_dir, f)) ] files.sort() logger.info("***** Running evaluation *****") logger.info(" Batch size = %d", args.eval_batch_size) model.eval() print("Evaluation. . .") nb_instances = 0 max_steps = args.max_steps if args.max_steps > 0 else np.inf global_step = 0 with torch.no_grad(): if args.do_eval: final_loss = 0.0 # for data_file in files: logger.info("file %s" % (data_file)) dataset = pretraining_dataset( input_file=data_file, max_pred_length=args.max_predictions_per_seq) if not multi_gpu_training: train_sampler = RandomSampler(dataset) datasetloader = DataLoader(dataset, sampler=train_sampler, batch_size=args.eval_batch_size, num_workers=4, pin_memory=True) else: train_sampler = DistributedSampler(dataset) datasetloader = DataLoader(dataset, sampler=train_sampler, batch_size=args.eval_batch_size, num_workers=4, pin_memory=True) for step, batch in enumerate( tqdm(datasetloader, desc="Iteration")): if global_step > max_steps: break batch = [t.to(device) for t in batch] input_ids, segment_ids, input_mask, masked_lm_labels, next_sentence_labels = batch #\ loss = model(input_ids=input_ids, token_type_ids=segment_ids, attention_mask=input_mask, masked_lm_labels=masked_lm_labels, next_sentence_label=next_sentence_labels) final_loss += loss global_step += 1 torch.cuda.empty_cache() if global_step > max_steps: break final_loss /= global_step if multi_gpu_training: final_loss /= torch.distributed.get_world_size() dist.all_reduce(final_loss) if (not multi_gpu_training or (multi_gpu_training and torch.distributed.get_rank() == 0)): logger.info("Finished: Final Loss = {}".format(final_loss)) else: # inference # if multi_gpu_training: # torch.distributed.barrier() # start_t0 = time.time() for data_file in files: logger.info("file %s" % (data_file)) dataset = pretraining_dataset( input_file=data_file, max_pred_length=args.max_predictions_per_seq) if not multi_gpu_training: train_sampler = RandomSampler(dataset) datasetloader = DataLoader(dataset, sampler=train_sampler, batch_size=args.eval_batch_size, num_workers=4, pin_memory=True) else: train_sampler = DistributedSampler(dataset) datasetloader = DataLoader(dataset, sampler=train_sampler, batch_size=args.eval_batch_size, num_workers=4, pin_memory=True) for step, batch in enumerate( tqdm(datasetloader, desc="Iteration")): if global_step > max_steps: break batch = [t.to(device) for t in batch] input_ids, segment_ids, input_mask, masked_lm_labels, next_sentence_labels = batch #\ lm_logits, nsp_logits = model(input_ids=input_ids, token_type_ids=segment_ids, attention_mask=input_mask, masked_lm_labels=None, next_sentence_label=None) nb_instances += input_ids.size(0) global_step += 1 torch.cuda.empty_cache() if global_step > max_steps: break # if multi_gpu_training: # torch.distributed.barrier() if (not multi_gpu_training or (multi_gpu_training and torch.distributed.get_rank() == 0)): logger.info("Finished")
def inference(args):
    start_t = time.time()
    bert_module = BertForPreTraining(
        args.vocab_size,
        args.seq_length,
        args.hidden_size,
        args.num_hidden_layers,
        args.num_attention_heads,
        args.intermediate_size,
        nn.GELU(),
        args.hidden_dropout_prob,
        args.attention_probs_dropout_prob,
        args.max_position_embeddings,
        args.type_vocab_size,
        args.vocab_size,
    )
    end_t = time.time()
    print("Initialize model using time: {:.3f}s".format(end_t - start_t))

    start_t = time.time()
    if args.use_lazy_model:
        from utils.compare_lazy_outputs import load_params_from_lazy

        load_params_from_lazy(
            bert_module.state_dict(),
            args.model_path,
        )
    else:
        bert_module.load_state_dict(flow.load(args.model_path))
    end_t = time.time()
    print("Loading parameters using time: {:.3f}s".format(end_t - start_t))

    bert_module.eval()
    bert_module.to(args.device)

    class BertEvalGraph(nn.Graph):
        def __init__(self):
            super().__init__()
            self.bert = bert_module

        def build(self, input_ids, input_masks, segment_ids):
            input_ids = input_ids.to(device=args.device)
            input_masks = input_masks.to(device=args.device)
            segment_ids = segment_ids.to(device=args.device)

            with flow.no_grad():
                # 1. forward the next_sentence_prediction and masked_lm model
                _, seq_relationship_scores = self.bert(input_ids, input_masks, segment_ids)

            return seq_relationship_scores

    bert_eval_graph = BertEvalGraph()

    start_t = time.time()
    inputs = [np.random.randint(0, 20, size=args.seq_length)]
    inputs = flow.Tensor(inputs, dtype=flow.int64, device=flow.device(args.device))
    mask = flow.cast(inputs > 0, dtype=flow.int64)
    segment_info = flow.zeros_like(inputs)
    prediction = bert_eval_graph(inputs, mask, segment_info)
    print(prediction.numpy())
    end_t = time.time()
    print("Inference using time: {:.3f}s".format(end_t - start_t))
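

# Hypothetical usage sketch, not part of the original script: inference() above only
# reads plain attributes from `args`, so a SimpleNamespace is enough to drive it. The
# sizes below are assumed BERT-base-like values and "path/to/bert_model" is a
# placeholder, not a path from this repository.
from types import SimpleNamespace

example_args = SimpleNamespace(
    vocab_size=30522, seq_length=128, hidden_size=768,
    num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072,
    hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1,
    max_position_embeddings=512, type_vocab_size=2,
    use_lazy_model=False, model_path="path/to/bert_model",
    device="cuda",
)
# inference(example_args)  # would run one random-input forward pass through BertEvalGraph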
def main():
    args = get_config()

    if args.with_cuda:
        device = flow.device("cuda")
    else:
        device = flow.device("cpu")

    print("Creating Dataloader")
    train_data_loader = OfRecordDataLoader(
        ofrecord_dir=args.ofrecord_path,
        mode="train",
        dataset_size=args.train_dataset_size,
        batch_size=args.train_batch_size,
        data_part_num=args.train_data_part,
        seq_length=args.seq_length,
        max_predictions_per_seq=args.max_predictions_per_seq,
        consistent=False,
    )

    test_data_loader = OfRecordDataLoader(
        ofrecord_dir=args.ofrecord_path,
        mode="test",
        dataset_size=1024,
        batch_size=args.val_batch_size,
        data_part_num=4,
        seq_length=args.seq_length,
        max_predictions_per_seq=args.max_predictions_per_seq,
        consistent=False,
    )

    print("Building BERT Model")
    hidden_size = 64 * args.num_attention_heads
    intermediate_size = 4 * hidden_size
    bert_model = BertForPreTraining(
        args.vocab_size,
        args.seq_length,
        hidden_size,
        args.num_hidden_layers,
        args.num_attention_heads,
        intermediate_size,
        nn.GELU(),
        args.hidden_dropout_prob,
        args.attention_probs_dropout_prob,
        args.max_position_embeddings,
        args.type_vocab_size,
    )

    # Load the same initial parameters with lazy model.
    # from utils.compare_lazy_outputs import load_params_from_lazy
    # load_params_from_lazy(
    #     bert_model.state_dict(),
    #     "../../OneFlow-Benchmark/LanguageModeling/BERT/initial_model",
    # )

    bert_model = bert_model.to(device)
    if args.use_ddp:
        bert_model = ddp(bert_model)

    optimizer = build_optimizer(
        args.optim_name,
        bert_model,
        args.lr,
        args.weight_decay,
        weight_decay_excludes=["bias", "LayerNorm", "layer_norm"],
        clip_grad_max_norm=1,
        clip_grad_norm_type=2.0,
    )

    steps = args.epochs * len(train_data_loader)
    warmup_steps = int(steps * args.warmup_proportion)

    lr_scheduler = PolynomialLR(optimizer, steps=steps, end_learning_rate=0.0)
    lr_scheduler = flow.optim.lr_scheduler.WarmUpLR(
        lr_scheduler, warmup_factor=0, warmup_iters=warmup_steps, warmup_method="linear"
    )

    ns_criterion = nn.CrossEntropyLoss(reduction="mean")
    mlm_criterion = nn.CrossEntropyLoss(reduction="none")

    def get_masked_lm_loss(
        logit_blob,
        masked_lm_positions,
        masked_lm_labels,
        label_weights,
        max_prediction_per_seq,
    ):
        # gather valid position indices
        logit_blob = flow.gather(
            logit_blob,
            index=masked_lm_positions.unsqueeze(2).repeat(1, 1, args.vocab_size),
            dim=1,
        )
        logit_blob = flow.reshape(logit_blob, [-1, args.vocab_size])
        label_id_blob = flow.reshape(masked_lm_labels, [-1])

        # The `positions` tensor might be zero-padded (if the sequence is too
        # short to have the maximum number of predictions). The `label_weights`
        # tensor has a value of 1.0 for every real prediction and 0.0 for the
        # padding predictions.
        pre_example_loss = mlm_criterion(logit_blob, label_id_blob)
        pre_example_loss = flow.reshape(pre_example_loss, [-1, max_prediction_per_seq])
        numerator = flow.sum(pre_example_loss * label_weights)
        denominator = flow.sum(label_weights) + 1e-5
        loss = numerator / denominator
        return loss

    train_total_losses = []
    for epoch in range(args.epochs):
        metric = Metric(
            desc="bert pretrain",
            print_steps=args.loss_print_every_n_iters,
            batch_size=args.train_batch_size,
            keys=["total_loss", "mlm_loss", "nsp_loss", "pred_acc"],
        )

        # Train
        bert_model.train()

        for step in range(len(train_data_loader)):
            bert_outputs = pretrain(
                train_data_loader,
                bert_model,
                ns_criterion,
                partial(
                    get_masked_lm_loss,
                    max_prediction_per_seq=args.max_predictions_per_seq,
                ),
                optimizer,
                lr_scheduler,
            )

            if flow.env.get_rank() == 0:
                metric.metric_cb(step, epoch=epoch)(bert_outputs)

            train_total_losses.append(bert_outputs["total_loss"])

        # Eval
        bert_model.eval()
        val_acc = validation(epoch, test_data_loader, bert_model, args.val_print_every_n_iters)

        save_model(bert_model, args.checkpoint_path, epoch, val_acc, False)
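

# Illustrative sketch, not part of the training script: get_masked_lm_loss above
# computes a weighted average over prediction slots, where `label_weights` is 1.0 for
# real masked positions and 0.0 for zero-padded slots. The same arithmetic with toy
# NumPy values (assumed numbers, 3 prediction slots per sample):
import numpy as np

per_example_loss = np.array([[2.3, 1.7, 0.9],    # sample 0: three real predictions
                             [3.1, 0.4, 0.0]])   # sample 1: last slot is padding
label_weights = np.array([[1.0, 1.0, 1.0],
                          [1.0, 1.0, 0.0]])      # 0.0 masks out the padded slot
numerator = np.sum(per_example_loss * label_weights)
denominator = np.sum(label_weights) + 1e-5       # epsilon guards against an all-padding batch
masked_lm_loss_example = numerator / denominator  # ~1.68: mean loss over the 5 real predictions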
def main():
    args = get_config()

    world_size = flow.env.get_world_size()
    if args.train_global_batch_size is None:
        args.train_global_batch_size = args.train_batch_size * world_size
    else:
        assert args.train_global_batch_size % args.train_batch_size == 0

    if args.val_global_batch_size is None:
        args.val_global_batch_size = args.val_batch_size * world_size
    else:
        assert args.val_global_batch_size % args.val_batch_size == 0

    flow.boxing.nccl.set_fusion_threshold_mbytes(args.nccl_fusion_threshold_mb)
    flow.boxing.nccl.set_fusion_max_ops_num(args.nccl_fusion_max_ops)

    if args.with_cuda:
        device = "cuda"
    else:
        device = "cpu"

    print("Device is: ", device)

    print("Creating Dataloader")
    train_data_loader = OfRecordDataLoader(
        ofrecord_dir=args.ofrecord_path,
        mode="train",
        dataset_size=args.train_dataset_size,
        batch_size=args.train_global_batch_size,
        data_part_num=args.train_data_part,
        seq_length=args.seq_length,
        max_predictions_per_seq=args.max_predictions_per_seq,
        consistent=args.use_consistent,
    )

    test_data_loader = OfRecordDataLoader(
        ofrecord_dir=args.ofrecord_path,
        mode="test",
        dataset_size=1024,
        batch_size=args.val_global_batch_size,
        data_part_num=4,
        seq_length=args.seq_length,
        max_predictions_per_seq=args.max_predictions_per_seq,
        consistent=args.use_consistent,
    )

    print("Building BERT Model")
    hidden_size = 64 * args.num_attention_heads
    intermediate_size = 4 * hidden_size
    bert_model = BertForPreTraining(
        args.vocab_size,
        args.seq_length,
        hidden_size,
        args.num_hidden_layers,
        args.num_attention_heads,
        intermediate_size,
        nn.GELU(),
        args.hidden_dropout_prob,
        args.attention_probs_dropout_prob,
        args.max_position_embeddings,
        args.type_vocab_size,
    )

    # Load the same initial parameters with lazy model.
    # from utils.compare_lazy_outputs import load_params_from_lazy
    # load_params_from_lazy(
    #     bert_model.state_dict(),
    #     "../../OneFlow-Benchmark/LanguageModeling/BERT/initial_model",
    # )

    assert id(bert_model.cls.predictions.decoder.weight) == id(
        bert_model.bert.embeddings.word_embeddings.weight
    )

    ns_criterion = nn.CrossEntropyLoss(reduction="mean")
    mlm_criterion = nn.CrossEntropyLoss(reduction="none")

    if args.use_consistent:
        placement = flow.env.all_device_placement("cuda")
        bert_model = bert_model.to_consistent(placement=placement, sbp=flow.sbp.broadcast)
    else:
        bert_model.to(device)
        ns_criterion.to(device)
        mlm_criterion.to(device)

    optimizer = build_optimizer(
        args.optim_name,
        bert_model,
        args.lr,
        args.weight_decay,
        weight_decay_excludes=["bias", "LayerNorm", "layer_norm"],
        clip_grad_max_norm=1,
        clip_grad_norm_type=2.0,
    )

    steps = args.epochs * len(train_data_loader)
    warmup_steps = int(steps * args.warmup_proportion)

    lr_scheduler = PolynomialLR(optimizer, steps=steps, end_learning_rate=0.0)
    lr_scheduler = flow.optim.lr_scheduler.WarmUpLR(
        lr_scheduler, warmup_factor=0, warmup_iters=warmup_steps, warmup_method="linear"
    )

    def get_masked_lm_loss(
        logit,
        masked_lm_labels,
        label_weights,
        max_predictions_per_seq,
    ):
        label_id = flow.reshape(masked_lm_labels, [-1])

        # The `positions` tensor might be zero-padded (if the sequence is too
        # short to have the maximum number of predictions). The `label_weights`
        # tensor has a value of 1.0 for every real prediction and 0.0 for the
        # padding predictions.
        pre_example_loss = mlm_criterion(logit, label_id)
        pre_example_loss = flow.reshape(pre_example_loss, [-1, max_predictions_per_seq])
        numerator = flow.sum(pre_example_loss * label_weights)
        denominator = flow.sum(label_weights) + 1e-5
        loss = numerator / denominator
        return loss

    class BertGraph(nn.Graph):
        def __init__(self):
            super().__init__()
            self.bert = bert_model
            self.ns_criterion = ns_criterion
            self.masked_lm_criterion = partial(
                get_masked_lm_loss, max_predictions_per_seq=args.max_predictions_per_seq
            )
            self.add_optimizer(optimizer, lr_sch=lr_scheduler)
            self._train_data_loader = train_data_loader
            if args.grad_acc_steps > 1:
                self.config.set_gradient_accumulation_steps(args.grad_acc_steps)
            if args.use_fp16:
                self.config.enable_amp(True)
                grad_scaler = flow.amp.GradScaler(
                    init_scale=2 ** 30,
                    growth_factor=2.0,
                    backoff_factor=0.5,
                    growth_interval=2000,
                )
                self.set_grad_scaler(grad_scaler)
            self.config.allow_fuse_add_to_output(True)
            self.config.allow_fuse_model_update_ops(True)

        def build(self):
            (
                input_ids,
                next_sentence_labels,
                input_mask,
                segment_ids,
                masked_lm_ids,
                masked_lm_positions,
                masked_lm_weights,
            ) = self._train_data_loader()
            input_ids = input_ids.to(device=device)
            input_mask = input_mask.to(device=device)
            segment_ids = segment_ids.to(device=device)
            next_sentence_labels = next_sentence_labels.to(device=device)
            masked_lm_ids = masked_lm_ids.to(device=device)
            masked_lm_positions = masked_lm_positions.to(device=device)
            masked_lm_weights = masked_lm_weights.to(device=device)

            # 1. forward the next_sentence_prediction and masked_lm model
            prediction_scores, seq_relationship_scores = self.bert(
                input_ids, segment_ids, input_mask, masked_lm_positions
            )

            # 2-1. loss of is_next classification result
            next_sentence_loss = self.ns_criterion(
                seq_relationship_scores.reshape(-1, 2), next_sentence_labels.reshape(-1)
            )

            masked_lm_loss = self.masked_lm_criterion(
                prediction_scores, masked_lm_ids, masked_lm_weights
            )

            total_loss = masked_lm_loss + next_sentence_loss

            total_loss.backward()
            return (
                seq_relationship_scores,
                next_sentence_labels,
                total_loss,
                masked_lm_loss,
                next_sentence_loss,
            )

    bert_graph = BertGraph()

    class BertEvalGraph(nn.Graph):
        def __init__(self):
            super().__init__()
            self.bert = bert_model
            self._test_data_loader = test_data_loader
            self.config.allow_fuse_add_to_output(True)

        def build(self):
            (
                input_ids,
                next_sent_labels,
                input_masks,
                segment_ids,
                masked_lm_ids,
                masked_lm_positions,
                masked_lm_weights,
            ) = self._test_data_loader()
            input_ids = input_ids.to(device=device)
            input_masks = input_masks.to(device=device)
            segment_ids = segment_ids.to(device=device)
            next_sent_labels = next_sent_labels.to(device=device)
            masked_lm_ids = masked_lm_ids.to(device=device)
            masked_lm_positions = masked_lm_positions.to(device)

            with flow.no_grad():
                # 1. forward the next_sentence_prediction and masked_lm model
                _, seq_relationship_scores = self.bert(input_ids, input_masks, segment_ids)

            return seq_relationship_scores, next_sent_labels

    bert_eval_graph = BertEvalGraph()

    train_total_losses = []
    for epoch in range(args.epochs):
        metric = Metric(
            desc="bert pretrain",
            print_steps=args.loss_print_every_n_iters,
            batch_size=args.train_global_batch_size * args.grad_acc_steps,
            keys=["total_loss", "mlm_loss", "nsp_loss", "pred_acc"],
        )

        # Train
        bert_model.train()

        for step in range(len(train_data_loader)):
            bert_outputs = pretrain(bert_graph, args.metric_local)

            if flow.env.get_rank() == 0:
                metric.metric_cb(step, epoch=epoch)(bert_outputs)

            train_total_losses.append(bert_outputs["total_loss"])

        # Eval
        bert_model.eval()
        val_acc = validation(
            epoch,
            len(test_data_loader),
            bert_eval_graph,
            args.val_print_every_n_iters,
            args.metric_local,
        )

        save_model(bert_model, args.checkpoint_path, epoch, val_acc, args.use_consistent)