def main():
    args = get_config()

    world_size = flow.env.get_world_size()
    if args.train_global_batch_size is None:
        args.train_global_batch_size = args.train_batch_size * world_size
    else:
        assert args.train_global_batch_size % args.train_batch_size == 0

    if args.val_global_batch_size is None:
        args.val_global_batch_size = args.val_batch_size * world_size
    else:
        assert args.val_global_batch_size % args.val_batch_size == 0

    flow.boxing.nccl.set_fusion_threshold_mbytes(args.nccl_fusion_threshold_mb)
    flow.boxing.nccl.set_fusion_max_ops_num(args.nccl_fusion_max_ops)

    if args.with_cuda:
        device = "cuda"
    else:
        device = "cpu"

    print("Device is: ", device)

    print("Creating Dataloader")
    train_data_loader = OfRecordDataLoader(
        ofrecord_dir=args.ofrecord_path,
        mode="train",
        dataset_size=args.train_dataset_size,
        batch_size=args.train_global_batch_size,
        data_part_num=args.train_data_part,
        seq_length=args.seq_length,
        max_predictions_per_seq=args.max_predictions_per_seq,
        consistent=args.use_consistent,
    )

    test_data_loader = OfRecordDataLoader(
        ofrecord_dir=args.ofrecord_path,
        mode="test",
        dataset_size=1024,
        batch_size=args.val_global_batch_size,
        data_part_num=4,
        seq_length=args.seq_length,
        max_predictions_per_seq=args.max_predictions_per_seq,
        consistent=args.use_consistent,
    )

    print("Building BERT Model")
    hidden_size = 64 * args.num_attention_heads
    intermediate_size = 4 * hidden_size
    bert_model = BertForPreTraining(
        args.vocab_size,
        args.seq_length,
        hidden_size,
        args.num_hidden_layers,
        args.num_attention_heads,
        intermediate_size,
        nn.GELU(),
        args.hidden_dropout_prob,
        args.attention_probs_dropout_prob,
        args.max_position_embeddings,
        args.type_vocab_size,
    )

    # Load the same initial parameters with lazy model.
    # from utils.compare_lazy_outputs import load_params_from_lazy
    # load_params_from_lazy(
    #     bert_model.state_dict(),
    #     "../../OneFlow-Benchmark/LanguageModeling/BERT/initial_model",
    # )

    assert id(bert_model.cls.predictions.decoder.weight) == id(
        bert_model.bert.embeddings.word_embeddings.weight
    )

    ns_criterion = nn.CrossEntropyLoss(reduction="mean")
    mlm_criterion = nn.CrossEntropyLoss(reduction="none")

    if args.use_consistent:
        placement = flow.env.all_device_placement("cuda")
        bert_model = bert_model.to_consistent(
            placement=placement, sbp=flow.sbp.broadcast
        )
    else:
        bert_model.to(device)
        ns_criterion.to(device)
        mlm_criterion.to(device)

    optimizer = build_optimizer(
        args.optim_name,
        bert_model,
        args.lr,
        args.weight_decay,
        weight_decay_excludes=["bias", "LayerNorm", "layer_norm"],
        clip_grad_max_norm=1,
        clip_grad_norm_type=2.0,
    )

    steps = args.epochs * len(train_data_loader)
    warmup_steps = int(steps * args.warmup_proportion)

    lr_scheduler = PolynomialLR(optimizer, steps=steps, end_learning_rate=0.0)
    lr_scheduler = flow.optim.lr_scheduler.WarmUpLR(
        lr_scheduler, warmup_factor=0, warmup_iters=warmup_steps, warmup_method="linear"
    )

    def get_masked_lm_loss(
        logit,
        masked_lm_labels,
        label_weights,
        max_predictions_per_seq,
    ):
        label_id = flow.reshape(masked_lm_labels, [-1])

        # The `positions` tensor might be zero-padded (if the sequence is too
        # short to have the maximum number of predictions). The `label_weights`
        # tensor has a value of 1.0 for every real prediction and 0.0 for the
        # padding predictions.
        pre_example_loss = mlm_criterion(logit, label_id)
        pre_example_loss = flow.reshape(pre_example_loss, [-1, max_predictions_per_seq])
        numerator = flow.sum(pre_example_loss * label_weights)
        denominator = flow.sum(label_weights) + 1e-5
        loss = numerator / denominator
        return loss

    class BertGraph(nn.Graph):
        def __init__(self):
            super().__init__()
            self.bert = bert_model
            self.ns_criterion = ns_criterion
            self.masked_lm_criterion = partial(
                get_masked_lm_loss, max_predictions_per_seq=args.max_predictions_per_seq
            )
            self.add_optimizer(optimizer, lr_sch=lr_scheduler)
            self._train_data_loader = train_data_loader
            if args.grad_acc_steps > 1:
                self.config.set_gradient_accumulation_steps(args.grad_acc_steps)
            if args.use_fp16:
                self.config.enable_amp(True)
                grad_scaler = flow.amp.GradScaler(
                    init_scale=2 ** 30,
                    growth_factor=2.0,
                    backoff_factor=0.5,
                    growth_interval=2000,
                )
                self.set_grad_scaler(grad_scaler)

            self.config.allow_fuse_add_to_output(True)
            self.config.allow_fuse_model_update_ops(True)

        def build(self):
            (
                input_ids,
                next_sentence_labels,
                input_mask,
                segment_ids,
                masked_lm_ids,
                masked_lm_positions,
                masked_lm_weights,
            ) = self._train_data_loader()
            input_ids = input_ids.to(device=device)
            input_mask = input_mask.to(device=device)
            segment_ids = segment_ids.to(device=device)
            next_sentence_labels = next_sentence_labels.to(device=device)
            masked_lm_ids = masked_lm_ids.to(device=device)
            masked_lm_positions = masked_lm_positions.to(device=device)
            masked_lm_weights = masked_lm_weights.to(device=device)

            # 1. forward the next_sentence_prediction and masked_lm model
            prediction_scores, seq_relationship_scores = self.bert(
                input_ids, segment_ids, input_mask, masked_lm_positions
            )

            # 2-1. loss of is_next classification result
            next_sentence_loss = self.ns_criterion(
                seq_relationship_scores.reshape(-1, 2), next_sentence_labels.reshape(-1)
            )

            masked_lm_loss = self.masked_lm_criterion(
                prediction_scores, masked_lm_ids, masked_lm_weights
            )

            total_loss = masked_lm_loss + next_sentence_loss

            total_loss.backward()
            return (
                seq_relationship_scores,
                next_sentence_labels,
                total_loss,
                masked_lm_loss,
                next_sentence_loss,
            )

    bert_graph = BertGraph()

    class BertEvalGraph(nn.Graph):
        def __init__(self):
            super().__init__()
            self.bert = bert_model
            self._test_data_loader = test_data_loader

            self.config.allow_fuse_add_to_output(True)

        def build(self):
            (
                input_ids,
                next_sent_labels,
                input_masks,
                segment_ids,
                masked_lm_ids,
                masked_lm_positions,
                masked_lm_weights,
            ) = self._test_data_loader()
            input_ids = input_ids.to(device=device)
            input_masks = input_masks.to(device=device)
            segment_ids = segment_ids.to(device=device)
            next_sent_labels = next_sent_labels.to(device=device)
            masked_lm_ids = masked_lm_ids.to(device=device)
            masked_lm_positions = masked_lm_positions.to(device)

            with flow.no_grad():
                # 1. forward the next_sentence_prediction and masked_lm model
                _, seq_relationship_scores = self.bert(
                    input_ids, input_masks, segment_ids
                )

            return seq_relationship_scores, next_sent_labels

    bert_eval_graph = BertEvalGraph()

    train_total_losses = []

    for epoch in range(args.epochs):
        metric = Metric(
            desc="bert pretrain",
            print_steps=args.loss_print_every_n_iters,
            batch_size=args.train_global_batch_size * args.grad_acc_steps,
            keys=["total_loss", "mlm_loss", "nsp_loss", "pred_acc"],
        )

        # Train
        bert_model.train()

        for step in range(len(train_data_loader)):
            bert_outputs = pretrain(bert_graph, args.metric_local)

            if flow.env.get_rank() == 0:
                metric.metric_cb(step, epoch=epoch)(bert_outputs)

            train_total_losses.append(bert_outputs["total_loss"])

        # Eval
        bert_model.eval()
        val_acc = validation(
            epoch,
            len(test_data_loader),
            bert_eval_graph,
            args.val_print_every_n_iters,
            args.metric_local,
        )

        save_model(bert_model, args.checkpoint_path, epoch, val_acc, args.use_consistent)
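# NOTE: `pretrain`, `validation`, `save_model`, `Metric`, `OfRecordDataLoader`
# and `build_optimizer` are helpers defined elsewhere in the repository. As a
# rough illustration of what the graph-mode `pretrain` helper is expected to
# do, a minimal sketch is given below; the function name, the consistent-to-
# local conversion, and the exact metric packaging are assumptions, not the
# repository's actual implementation.
def pretrain_sketch(graph, metric_local=True):
    # One call to the compiled nn.Graph runs forward, backward and the
    # optimizer update, and returns the tensors listed in BertGraph.build().
    (
        seq_relationship_scores,
        next_sentence_labels,
        total_loss,
        mlm_loss,
        nsp_loss,
    ) = graph()

    def to_numpy(tensor):
        # In consistent mode the outputs are global tensors; take the local
        # shard on this rank before converting (assumes metric_local=True
        # together with consistent execution).
        if metric_local:
            tensor = tensor.to_local()
        return tensor.numpy()

    # Next-sentence-prediction accuracy for the "pred_acc" metric key.
    preds = to_numpy(seq_relationship_scores).argmax(axis=-1).flatten()
    labels = to_numpy(next_sentence_labels).flatten()
    pred_acc = float((preds == labels).mean())

    return {
        "total_loss": to_numpy(total_loss).mean(),
        "mlm_loss": to_numpy(mlm_loss).mean(),
        "nsp_loss": to_numpy(nsp_loss).mean(),
        "pred_acc": pred_acc,
    }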
def main():
    args = get_config()

    if args.with_cuda:
        device = flow.device("cuda")
    else:
        device = flow.device("cpu")

    print("Creating Dataloader")
    train_data_loader = OfRecordDataLoader(
        ofrecord_dir=args.ofrecord_path,
        mode="train",
        dataset_size=args.train_dataset_size,
        batch_size=args.train_batch_size,
        data_part_num=args.train_data_part,
        seq_length=args.seq_length,
        max_predictions_per_seq=args.max_predictions_per_seq,
        consistent=False,
    )

    test_data_loader = OfRecordDataLoader(
        ofrecord_dir=args.ofrecord_path,
        mode="test",
        dataset_size=1024,
        batch_size=args.val_batch_size,
        data_part_num=4,
        seq_length=args.seq_length,
        max_predictions_per_seq=args.max_predictions_per_seq,
        consistent=False,
    )

    print("Building BERT Model")
    hidden_size = 64 * args.num_attention_heads
    intermediate_size = 4 * hidden_size
    bert_model = BertForPreTraining(
        args.vocab_size,
        args.seq_length,
        hidden_size,
        args.num_hidden_layers,
        args.num_attention_heads,
        intermediate_size,
        nn.GELU(),
        args.hidden_dropout_prob,
        args.attention_probs_dropout_prob,
        args.max_position_embeddings,
        args.type_vocab_size,
    )

    # Load the same initial parameters with lazy model.
    # from utils.compare_lazy_outputs import load_params_from_lazy
    # load_params_from_lazy(
    #     bert_model.state_dict(),
    #     "../../OneFlow-Benchmark/LanguageModeling/BERT/initial_model",
    # )

    bert_model = bert_model.to(device)
    if args.use_ddp:
        bert_model = ddp(bert_model)

    optimizer = build_optimizer(
        args.optim_name,
        bert_model,
        args.lr,
        args.weight_decay,
        weight_decay_excludes=["bias", "LayerNorm", "layer_norm"],
        clip_grad_max_norm=1,
        clip_grad_norm_type=2.0,
    )

    steps = args.epochs * len(train_data_loader)
    warmup_steps = int(steps * args.warmup_proportion)

    lr_scheduler = PolynomialLR(optimizer, steps=steps, end_learning_rate=0.0)
    lr_scheduler = flow.optim.lr_scheduler.WarmUpLR(
        lr_scheduler, warmup_factor=0, warmup_iters=warmup_steps, warmup_method="linear"
    )

    ns_criterion = nn.CrossEntropyLoss(reduction="mean")
    mlm_criterion = nn.CrossEntropyLoss(reduction="none")

    def get_masked_lm_loss(
        logit_blob,
        masked_lm_positions,
        masked_lm_labels,
        label_weights,
        max_prediction_per_seq,
    ):
        # gather valid position indices
        logit_blob = flow.gather(
            logit_blob,
            index=masked_lm_positions.unsqueeze(2).repeat(1, 1, args.vocab_size),
            dim=1,
        )
        logit_blob = flow.reshape(logit_blob, [-1, args.vocab_size])
        label_id_blob = flow.reshape(masked_lm_labels, [-1])

        # The `positions` tensor might be zero-padded (if the sequence is too
        # short to have the maximum number of predictions). The `label_weights`
        # tensor has a value of 1.0 for every real prediction and 0.0 for the
        # padding predictions.
        pre_example_loss = mlm_criterion(logit_blob, label_id_blob)
        pre_example_loss = flow.reshape(pre_example_loss, [-1, max_prediction_per_seq])
        numerator = flow.sum(pre_example_loss * label_weights)
        denominator = flow.sum(label_weights) + 1e-5
        loss = numerator / denominator
        return loss

    train_total_losses = []

    for epoch in range(args.epochs):
        metric = Metric(
            desc="bert pretrain",
            print_steps=args.loss_print_every_n_iters,
            batch_size=args.train_batch_size,
            keys=["total_loss", "mlm_loss", "nsp_loss", "pred_acc"],
        )

        # Train
        bert_model.train()

        for step in range(len(train_data_loader)):
            bert_outputs = pretrain(
                train_data_loader,
                bert_model,
                ns_criterion,
                partial(
                    get_masked_lm_loss,
                    max_prediction_per_seq=args.max_predictions_per_seq,
                ),
                optimizer,
                lr_scheduler,
            )

            if flow.env.get_rank() == 0:
                metric.metric_cb(step, epoch=epoch)(bert_outputs)

            train_total_losses.append(bert_outputs["total_loss"])

        # Eval
        bert_model.eval()
        val_acc = validation(
            epoch, test_data_loader, bert_model, args.val_print_every_n_iters
        )

        save_model(bert_model, args.checkpoint_path, epoch, val_acc, False)
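# NOTE: in this eager-mode script the per-step work lives in the `pretrain`
# helper (defined elsewhere). The sketch below only illustrates the step it is
# expected to perform, based on the arguments passed above; the model call
# signature, the extra `device` parameter, and the returned metric keys are
# assumptions for illustration, not the repository's actual implementation.
def pretrain_step_sketch(
    data_loader, model, ns_criterion, masked_lm_criterion, optimizer, lr_scheduler, device
):
    # Fetch one batch and move it to the training device.
    (
        input_ids,
        next_sentence_labels,
        input_mask,
        segment_ids,
        masked_lm_ids,
        masked_lm_positions,
        masked_lm_weights,
    ) = data_loader()
    input_ids = input_ids.to(device)
    input_mask = input_mask.to(device)
    segment_ids = segment_ids.to(device)
    next_sentence_labels = next_sentence_labels.to(device)
    masked_lm_ids = masked_lm_ids.to(device)
    masked_lm_positions = masked_lm_positions.to(device)
    masked_lm_weights = masked_lm_weights.to(device)

    # Forward pass: full-sequence MLM logits and next-sentence scores
    # (assumed call signature of BertForPreTraining in eager mode).
    prediction_scores, seq_relationship_scores = model(
        input_ids, segment_ids, input_mask
    )

    next_sentence_loss = ns_criterion(
        seq_relationship_scores.reshape(-1, 2), next_sentence_labels.reshape(-1)
    )
    # The partial bound above still expects positions, labels and weights.
    masked_lm_loss = masked_lm_criterion(
        prediction_scores, masked_lm_positions, masked_lm_ids, masked_lm_weights
    )
    total_loss = next_sentence_loss + masked_lm_loss

    # Backward pass and parameter update (DDP syncs gradients here when enabled).
    total_loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    lr_scheduler.step()

    preds = seq_relationship_scores.detach().numpy().argmax(axis=-1).flatten()
    labels = next_sentence_labels.numpy().flatten()
    return {
        "total_loss": total_loss.detach().numpy().mean(),
        "mlm_loss": masked_lm_loss.detach().numpy().mean(),
        "nsp_loss": next_sentence_loss.detach().numpy().mean(),
        "pred_acc": float((preds == labels).mean()),
    }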
def main():
    hidden_size = 64 * args.num_attention_heads  # H = 64, size per head
    intermediate_size = hidden_size * 4

    print("Create Bert model for SQuAD")
    squad_model = SQuAD(
        args.vocab_size,
        seq_length=args.seq_length,
        hidden_size=hidden_size,
        hidden_layers=args.num_hidden_layers,
        atten_heads=args.num_attention_heads,
        intermediate_size=intermediate_size,
        hidden_act=nn.GELU(),
        hidden_dropout_prob=args.hidden_dropout_prob,
        attention_probs_dropout_prob=args.attention_probs_dropout_prob,
        max_position_embeddings=args.max_position_embeddings,
        type_vocab_size=args.type_vocab_size,
        initializer_range=0.02,
    )

    # Load pretrain model from lazy trained model
    load_params_from_lazy(squad_model.state_dict(), args.model_load_dir)

    squad_model.to(device)

    if args.do_train:
        print("Create SQuAD training data decoders")
        train_decoders = SquadDecoder(
            args.train_data_dir, batch_size, args.train_data_part_num, args.seq_length
        )

        optimizer = build_adamW_optimizer(
            squad_model,
            args.learning_rate,
            args.weight_decay_rate,
            weight_decay_excludes=["bias", "LayerNorm", "layer_norm"],
        )

        lr_scheduler = PolynomialLR(
            optimizer, steps=args.iter_num, end_learning_rate=0.0
        )

        warmup_batches = int(args.iter_num * args.warmup_proportion)
        lr_scheduler = flow.optim.lr_scheduler.WarmUpLR(
            lr_scheduler,
            warmup_factor=0,
            warmup_iters=warmup_batches,
            warmup_method="linear",
        )

        class SQuADGraph(nn.Graph):
            def __init__(self):
                super().__init__()
                self.squad_model = squad_model
                self.criterion = nn.CrossEntropyLoss()
                self.add_optimizer(optimizer, lr_sch=lr_scheduler)
                self._decoders = train_decoders
                if args.use_fp16:
                    self.config.enable_amp(True)
                    grad_scaler = flow.amp.GradScaler(
                        init_scale=2 ** 30,
                        growth_factor=2.0,
                        backoff_factor=0.5,
                        growth_interval=2000,
                    )
                    self.set_grad_scaler(grad_scaler)

            def build(self):
                (
                    input_ids,
                    input_mask,
                    segment_ids,
                    start_positions,
                    end_positions,
                ) = self._decoders()
                input_ids = input_ids.to(device=device)
                input_mask = input_mask.to(device=device)
                segment_ids = segment_ids.to(device=device)
                start_positions = start_positions.to(device=device)
                end_positions = end_positions.to(device=device)

                start_logits, end_logits = self.squad_model(
                    input_ids, segment_ids, input_mask
                )
                start_logits = flow.reshape(start_logits, [-1, args.seq_length])
                end_logits = flow.reshape(end_logits, [-1, args.seq_length])

                start_loss = self.criterion(start_logits, start_positions.squeeze(1))
                end_loss = self.criterion(end_logits, end_positions.squeeze(1))

                total_loss = (start_loss + end_loss) * 0.5
                total_loss.backward()

                return total_loss

        squad_graph = SQuADGraph()

        for epoch in range(args.num_epochs):
            squad_model.train()

            metric = Metric(
                desc="train",
                print_steps=args.loss_print_every_n_iter,
                batch_size=batch_size,
                keys=["total_loss"],
            )

            for step in range(epoch_size):
                metric.metric_cb(step, epoch=epoch)(squad_finetune(squad_graph))

        if args.save_last_snapshot:
            save_model(squad_model, args.model_save_dir, "last_snapshot")

    if args.do_eval:
        assert os.path.isdir(args.eval_data_dir)
        print("Create SQuAD testing data decoders")
        test_decoders = SquadDecoder(
            args.eval_data_dir,
            eval_batch_size,
            args.eval_data_part_num,
            args.seq_length,
            is_train=False,
        )

        squad_model.eval()

        class SQuADEvalGraph(nn.Graph):
            def __init__(self):
                super().__init__()
                self.squad_model = squad_model
                self._decoders = test_decoders

            def build(self):
                (input_ids, input_mask, segment_ids, unique_ids) = self._decoders()
                input_ids = input_ids.to(device=device)
                input_mask = input_mask.to(device=device)
                segment_ids = segment_ids.to(device=device)
                unique_ids = unique_ids.to(device=device)

                with flow.no_grad():
                    start_logits, end_logits = self.squad_model(
                        input_ids, segment_ids, input_mask
                    )

                return unique_ids, start_logits, end_logits

        squad_eval_graph = SQuADEvalGraph()

        squad_eval(num_eval_steps, squad_eval_graph, args.loss_print_every_n_iter)
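# NOTE: `device`, `batch_size`, `eval_batch_size`, `epoch_size`,
# `num_eval_steps`, `squad_finetune` and `squad_eval` come from the
# surrounding script and its helpers. Purely as an illustration, a training
# helper in the spirit of `squad_finetune` could be as small as the sketch
# below (the dict key matches the "total_loss" metric key used above); the
# name and packaging are assumptions, not the actual implementation.
def squad_finetune_sketch(graph):
    # One call to the nn.Graph executes forward, backward and the optimizer
    # update, and returns the total_loss defined in SQuADGraph.build().
    total_loss = graph()
    return {"total_loss": total_loss.numpy().mean()}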