def train_batch(self, batch, model, optimizer, is_mp):
    """Run one AMP step with ``auto_cast(enable=True, level="O2")``.

    A new ``GradScaler`` (``init_loss_scaling=5160``) is built on every call;
    when ``is_mp`` is true it is wrapped with ``fleet.distributed_scaler``.

    Args:
        batch: input passed straight to ``model``.
        model: network under training.
        optimizer: optimizer stepped through the scaler.
        is_mp: whether the model runs under fleet model parallelism.

    Returns:
        The tensor returned by ``scaler.scale(loss)`` (the scaled loss).
    """
    grad_scaler = paddle.amp.GradScaler(init_loss_scaling=5160)
    grad_scaler = fleet.distributed_scaler(grad_scaler) if is_mp else grad_scaler

    with paddle.amp.auto_cast(enable=True, level="O2"):
        loss = model(batch).mean()

    scaled_loss = grad_scaler.scale(loss)
    scaled_loss.backward()
    grad_scaler.step(optimizer)
    grad_scaler.update()
    optimizer.clear_grad()
    return scaled_loss
def train_batch(self, batch, model, optimizer, is_mp):
    """Run one AMP step using the default ``paddle.amp.auto_cast()`` arguments.

    A new ``GradScaler`` (``init_loss_scaling=5160``) is built on every call;
    when ``is_mp`` is true it is wrapped with ``fleet.distributed_scaler``.

    Args:
        batch: input passed straight to ``model``.
        model: network under training.
        optimizer: optimizer stepped through the scaler.
        is_mp: whether the model runs under fleet model parallelism.

    Returns:
        The tensor returned by ``scaler.scale(loss)`` (the scaled loss).
    """
    loss_scaler = paddle.amp.GradScaler(init_loss_scaling=5160)
    if is_mp:
        loss_scaler = fleet.distributed_scaler(loss_scaler)

    with paddle.amp.auto_cast():
        loss = model(batch).mean()

    scaled_loss = loss_scaler.scale(loss)
    scaled_loss.backward()           # backward on the scaled loss
    loss_scaler.step(optimizer)      # parameter update through the scaler
    loss_scaler.update()
    optimizer.clear_grad()
    return scaled_loss
def test_pp_model(self):
    """Compare pipeline-parallel AMP training against single-model AMP training.

    Model a is a plain ``AlexNet(10)``. Model b is an ``AlexNetPipeDesc``
    wrapped by ``fleet.distributed_model`` / ``distributed_optimizer`` /
    ``distributed_scaler`` and initialised from model a's weights for this
    pipeline stage. Both train on the same MNIST batches, and their losses
    must agree within ``rtol=5e-5`` for the first five batches.

    Returns:
        True once five batches have been compared; None if the reader
        yields fewer batches than that.
    """
    hcg = fleet.get_hybrid_communicate_group()
    dp_id = hcg.get_data_parallel_rank()
    pp_id = hcg.get_stage_id()
    rank_id = dist.get_rank()
    set_random_seed(1024, dp_id, rank_id)

    grad_clip = paddle.nn.ClipGradByGlobalNorm(1.0)

    # construct model a
    model_a = AlexNet(10)
    scheduler_a = paddle.optimizer.lr.PiecewiseDecay(
        boundaries=[2], values=[0.001, 0.002], verbose=True)
    optimizer_a = paddle.optimizer.SGD(learning_rate=scheduler_a,
                                       grad_clip=grad_clip,
                                       parameters=model_a.parameters())
    scaler_a = paddle.amp.GradScaler(init_loss_scaling=2**5)

    # Snapshot model a's initial weights so model b can start from them.
    param_len = len(model_a.parameters())
    parameters = [param.numpy() for param in model_a.parameters()]

    # construct model b
    model_b = AlexNetPipeDesc(num_stages=self.pipeline_parallel_size)
    scheduler_b = paddle.optimizer.lr.PiecewiseDecay(
        boundaries=[2], values=[0.001, 0.002], verbose=True)
    optimizer_b = paddle.optimizer.SGD(learning_rate=scheduler_b,
                                       grad_clip=grad_clip,
                                       parameters=model_b.parameters())
    model_b = fleet.distributed_model(model_b)
    optimizer_b = fleet.distributed_optimizer(optimizer_b)
    scaler_b = paddle.amp.GradScaler(init_loss_scaling=2**5)
    scaler_b = fleet.distributed_scaler(scaler_b)

    # NOTE(review): the offset assumes two pipeline stages that each own
    # exactly half of model a's parameters; confirm if the stage count changes.
    for idx, param in enumerate(model_b.parameters()):
        param.set_value(parameters[idx + pp_id * (param_len // 2)])

    # construct reader
    train_reader = paddle.batch(paddle.dataset.mnist.train(),
                                batch_size=batch_size,
                                drop_last=True)

    for step_id, data in enumerate(train_reader()):
        # Exit before converting a batch that would only be discarded.
        if step_id >= 5:
            return True

        x_data = np.array([x[0] for x in data]).astype('float32').reshape(
            batch_size, 1, 28, 28)
        y_data = np.array([x[1] for x in data]).astype('int64').reshape(
            batch_size, 1)
        img = paddle.to_tensor(x_data)
        label = paddle.to_tensor(y_data)
        img.stop_gradient = True
        label.stop_gradient = True

        # Reference step on the single model.
        with paddle.amp.auto_cast():
            loss_a = model_a(img, label)
        scaler_a.scale(loss_a).backward()
        scaler_a.minimize(optimizer_a, loss_a)
        optimizer_a.clear_grad()
        scheduler_a.step()

        # Same step through the pipeline-parallel wrapper.
        with paddle.amp.auto_cast():
            loss_b = model_b.train_batch([img, label],
                                         optimizer_b,
                                         scheduler_b,
                                         scaler=scaler_b)

        print("loss: ", loss_a.numpy(), loss_b.numpy())
        np.testing.assert_allclose(loss_a.numpy(), loss_b.numpy(), rtol=5e-5)
def do_train(args):
    """Run hybrid-parallel (dp/mp/pp/sharding) GPT pre-training configured by ``args``.

    Sets up fleet with the requested parallel degrees and builds the GPT model
    (the pipeline variant when ``args.pp_degree > 1``). The optimizer is either
    ``DygraphShardingOptimizer`` (sharding stage 1) or ``AdamW``. Optionally adds
    a pure-FP16 ``GradScaler`` and sharding stage 2/3 wrapping.

    Training iterates every data file for ``args.num_train_epochs`` epochs. It
    logs every ``args.logging_freq`` steps, evaluates every ``args.eval_freq``
    steps and saves every ``args.save_steps`` steps on ``dp_rank == 0``. When
    ``args.check_accuracy`` is set, evaluation and saving are skipped. The
    function returns once ``global_step`` reaches ``args.max_steps``.
    """
    paddle.set_device(args.device)
    nranks = paddle.distributed.get_world_size()
    strategy = fleet.DistributedStrategy()
    strategy.hybrid_configs = {
        "dp_degree": args.dp_degree,
        "mp_degree": args.mp_degree,
        "pp_degree": args.pp_degree,
        "sharding_degree": args.sharding_degree
    }

    accumulate_steps = args.local_batch_size // args.micro_batch_size
    strategy.pipeline_configs = {
        "accumulate_steps": accumulate_steps,
        "micro_batch_size": args.micro_batch_size
    }

    # set control in tensor parallel
    strategy.tensor_parallel_configs = {"tensor_init_seed": args.seed}

    fleet.init(is_collective=True, strategy=strategy)

    # obtain rank message of hybrid parallel
    hcg = fleet.get_hybrid_communicate_group()
    global_rank = hcg.get_global_rank()
    mp_rank = hcg.get_model_parallel_rank()
    pp_rank = hcg.get_stage_id()
    dp_rank = hcg.get_data_parallel_rank()
    sharding_rank = hcg.get_sharding_parallel_rank()

    # sharding stage2/3 not support hybrid parallel
    if args.sharding_stage in [2, 3]:
        assert args.dp_degree == args.mp_degree == args.pp_degree == 1, "sharding stage2/3 will support hybrid parallel later"

    # Data is split across the combined dp x sharding groups.
    sharding_size = hcg.get_sharding_parallel_world_size()
    data_world_rank = dp_rank * sharding_size + sharding_rank
    data_world_size = args.dp_degree * args.sharding_degree
    local_rank = int(os.getenv("PADDLE_RANK_IN_NODE", 0))

    # seed control in hybrid parallel
    set_hyrbid_parallel_seed(args.seed, data_world_rank, mp_rank, pp_rank)

    default_global_tokens_num = args.global_batch_size * args.max_seq_len

    model_class, tokenizer_class = MODEL_CLASSES[args.model_type]

    tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)

    # Define log writer (one directory per global rank, recreated each run).
    log_writer_path = os.path.join(
        args.output_dir, "train_log",
        "{}_globalbsz_{}_pure_fp16_{}_recompute_{}_card_{}".format(
            args.model_name_or_path, args.global_batch_size,
            args.use_pure_fp16, False, global_rank).lower())

    if os.path.exists(log_writer_path):
        import shutil
        shutil.rmtree(log_writer_path)

    log_writer = LogWriter(log_writer_path)

    pretrained_models_list = list(
        model_class.pretrained_init_configuration.keys())

    if args.model_name_or_path in pretrained_models_list:
        model_config = model_class.pretrained_init_configuration[
            args.model_name_or_path]
        model_config["hidden_dropout_prob"] = args.hidden_dropout_prob
        model_config[
            "attention_probs_dropout_prob"] = args.attention_probs_dropout_prob
        model_config['num_partitions'] = args.mp_degree
        model_config['use_recompute'] = args.use_recompute
        if args.pp_degree == 1:
            model = GPTForPretraining(GPTModel(**model_config))
        else:
            model_config['topology'] = hcg.topology()
            model = GPTForPretrainingPipe(**model_config)
    else:
        model = GPTForPretraining.from_pretrained(
            args.model_name_or_path,
            hidden_dropout_prob=args.hidden_dropout_prob,
            attention_probs_dropout_prob=args.attention_probs_dropout_prob)

    # Create the criterion for the gpt model
    criterion = GPTPretrainingCriterion()

    if args.decay_steps is None:
        args.decay_steps = args.max_steps
    warmup_step = args.warmup_rate * args.decay_steps

    lr_scheduler = None

    if args.lr_decay_style == "none":
        lr_scheduler = None
    elif args.lr_decay_style == "cosine":
        lr_scheduler = lr.CosineAnnealingWithWarmupDecay(
            max_lr=args.max_lr,
            min_lr=args.min_lr,
            warmup_step=warmup_step,
            decay_step=args.decay_steps)

    clip = None
    if args.grad_clip > 0:
        clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=args.grad_clip)

    # Generate parameter names needed to perform weight decay.
    # All bias and LayerNorm parameters are excluded.
    decay_params = [
        p.name for n, p in model.named_parameters()
        if not any(nd in n for nd in ["bias", "norm"])
    ]

    if args.sharding_stage == 1 and args.sharding_degree > 1:
        optimizer = DygraphShardingOptimizer(
            hcg=fleet.get_hybrid_communicate_group(),
            user_defined_strategy=strategy,
            params=model.parameters(),
            inner_optimizer_class=paddle.optimizer.AdamW,
            learning_rate=lr_scheduler
            if lr_scheduler is not None else args.max_lr,
            beta1=args.adam_beta1,
            beta2=args.adam_beta2,
            epsilon=args.adam_epsilon,
            weight_decay=args.weight_decay,
            grad_clip=clip,
            apply_decay_param_fun=lambda x: x in decay_params)
    else:
        optimizer = paddle.optimizer.AdamW(
            learning_rate=lr_scheduler
            if lr_scheduler is not None else args.max_lr,
            beta1=args.adam_beta1,
            beta2=args.adam_beta2,
            epsilon=args.adam_epsilon,
            parameters=model.parameters(),
            weight_decay=args.weight_decay,
            grad_clip=clip,
            apply_decay_param_fun=lambda x: x in decay_params,
            # TODO: remove 'multi_precision' in definition of optimizer
            # and add it to 'paddle.amp.decorate'
            multi_precision=args.use_pure_fp16)

    if args.use_pure_fp16:
        scaler = paddle.amp.GradScaler(init_loss_scaling=args.scale_loss)
        if args.sharding_stage not in [2, 3]:
            scaler = fleet.distributed_scaler(scaler)
        # level O2 means converting the network to FP16
        model = paddle.amp.decorate(
            models=model, level='O2', save_dtype='float32')

    # wrap sharding stage2/3 and add collective group
    # TODO(Baibaifan): combine ShardingStage1/2/3 and fleet.distributed_model in feature
    if args.sharding_stage in [2, 3]:
        scaler = scaler if args.use_pure_fp16 else None
        model, optimizer, scaler = wrap_sharding_2_3(model, optimizer, scaler,
                                                     args.sharding_offload)

    elif paddle.distributed.get_world_size() > 1:
        model = fleet.distributed_model(model)
        optimizer = fleet.distributed_optimizer(optimizer)

    if args.model_name_or_path not in pretrained_models_list:
        logger.info("Try to load checkpoint from %s " % args.model_name_or_path)
        opt_path = os.path.join(args.model_name_or_path, "model_state.pdopt")
        if os.path.exists(opt_path):
            opt_dict = paddle.load(opt_path)
            optimizer.set_state_dict(opt_dict)
        else:
            logger.warning("No optimizer checkpoint file found in %s." %
                           opt_path)

    global_step = 0
    for epoch in range(args.num_train_epochs):
        files = get_train_data_file(args)
        files.sort()
        num_files = len(files)
        for f_id in range(num_files):
            data_file = files[f_id]
            train_data_loader, valid_data_loader, test_data_loader = create_pretrained_dataset(
                args, [data_file],
                local_rank=local_rank,
                data_world_size=data_world_size,
                data_world_rank=data_world_rank,
                eos_id=tokenizer.eos_token_id)
            # Bug fix, if not call valid_data_loader, the enumerate will call valid_data_loader
            # many times. and start a new random dataloader.
            valid_data_loader = valid_data_loader()
            test_data_loader = test_data_loader()

            # time count
            train_reader_cost = 0.0
            train_run_cost = 0.0
            reader_start = time.time()
            for step, batch in enumerate(train_data_loader()):
                train_reader_cost += time.time() - reader_start
                train_start = time.time()

                global_step += 1
                tokens, loss_mask, position_ids, labels = batch

                loss_mask.stop_gradient = True
                labels.stop_gradient = True
                position_ids.stop_gradient = True

                if args.pp_degree == 1:
                    # In ParallelMode of DataParallel, 'no_sync' can be used for improving
                    # performance of model by gradient accumulation.
                    loss = 0.0
                    for i in range(accumulate_steps):
                        start_index = i * args.micro_batch_size
                        end_index = start_index + args.micro_batch_size
                        with paddle.amp.auto_cast(
                                args.use_pure_fp16,
                                custom_black_list=[
                                    "reduce_sum",
                                    "c_softmax_with_cross_entropy",
                                    "elementwise_div"
                                ],
                                level='O2'):
                            preds = model(
                                tokens[start_index:end_index, :],
                                position_ids[start_index:end_index, :])
                            loss_mbs = criterion(
                                preds, labels[start_index:end_index, :],
                                loss_mask[start_index:end_index, :])
                        # Average over micro-batches so the summed loss
                        # matches one full local batch.
                        loss_mbs = loss_mbs / accumulate_steps
                        if args.use_pure_fp16:
                            scaler.scale(loss_mbs).backward()
                        else:
                            loss_mbs.backward()
                        loss = loss + loss_mbs

                    if args.use_pure_fp16:
                        if args.sharding_stage in [2, 3]:
                            scaler.step(optimizer)
                            scaler.update()
                        else:
                            scaler.minimize(optimizer, loss)
                    else:
                        optimizer.step()

                    if lr_scheduler is not None:
                        lr_scheduler.step()

                    optimizer.clear_grad()

                else:
                    data = [(tokens, position_ids), (labels, loss_mask)]
                    with paddle.amp.auto_cast(
                            args.use_pure_fp16,
                            custom_black_list=[
                                "reduce_sum",
                                "c_softmax_with_cross_entropy",
                                "elementwise_div"
                            ],
                            level='O2'):
                        loss = model.train_batch(
                            data,
                            optimizer=optimizer,
                            lr_scheduler=lr_scheduler,
                            scaler=scaler if args.use_pure_fp16 else None)

                # Sync for profile time, delete it may be a little faster
                paddle.device.cuda.synchronize()
                train_run_cost += time.time() - train_start
                # Profile for model benchmark
                profiler.add_profiler_step(args.profiler_options)

                if global_step % args.logging_freq == 0:
                    avg_loss = loss.numpy()
                    speed = args.logging_freq / (
                        train_reader_cost + train_run_cost)
                    avg_reader_cost = train_reader_cost / args.logging_freq

                    logger.info(
                        "global step %d, epoch: %d, batch: %d, loss: %.9f, avg_reader_cost: %.5f sec, avg_batch_cost: %.5f sec, speed: %.2f step/s, ips: %.0f tokens/s, ips_per_card: %.0f tokens/s, learning rate: %.5e"
                        % (global_step, epoch, step, avg_loss,
                           avg_reader_cost, 1. / speed, speed,
                           speed * default_global_tokens_num,
                           speed * default_global_tokens_num / nranks,
                           optimizer.get_lr()))
                    log_writer.add_scalar("loss", float(loss), global_step)
                    log_writer.add_scalar("learning_rate",
                                          optimizer.get_lr(), global_step)

                    train_reader_cost = 0.0
                    train_run_cost = 0.0

                if args.check_accuracy:
                    if global_step >= args.max_steps:
                        return
                    # Restart the reader timer before skipping eval/save;
                    # otherwise the next step's reader cost would include
                    # all time elapsed since the last reset.
                    reader_start = time.time()
                    continue

                if global_step % args.eval_freq == 0:
                    # Since the valid data broardcast to all devices, we do evaluate on all device.
                    run_evaluate(args, valid_data_loader, model, criterion,
                                 args.eval_iters, log_writer, global_step,
                                 epoch, "valid")

                # TODO: 1. merge paramters while saving model. 2. ensure that the model is saved and loaded correctly
                # only dp_rank = 0 save model
                if (global_step % args.save_steps == 0 or
                        global_step >= args.max_steps) and dp_rank == 0:

                    model_to_save = model._layers if paddle.distributed.get_world_size(
                    ) > 1 and args.sharding_stage not in [2, 3] else model
                    output_dir = os.path.join(args.output_dir,
                                              "step_%d" % global_step)
                    os.makedirs(output_dir, exist_ok=True)

                    logger.info("Save model to %s" % output_dir)

                    if args.pp_degree > 1:
                        if mp_rank == 0 and sharding_rank == 0 and pp_rank == 0:
                            tokenizer.save_pretrained(output_dir)
                        # Every mp/sharding/pp rank writes its own shard.
                        model_to_save.save_state_dict(output_dir)
                        paddle.save(
                            optimizer.state_dict(),
                            os.path.join(
                                output_dir,
                                "model_state_mp_{:0>2d}_sharding_{:0>2d}_pp_{:0>2d}.pdopt".
                                format(mp_rank, sharding_rank, pp_rank)))
                    else:
                        if args.sharding_stage == 3:
                            # If parameter need to convert to cpu, please add convert2cpu=True
                            model_to_save.get_all_parameters(convert2cpu=False)
                        if mp_rank == 0 and sharding_rank == 0:
                            tokenizer.save_pretrained(output_dir)
                        model_to_save.save_pretrained(output_dir)
                        paddle.save(
                            optimizer.state_dict(),
                            os.path.join(
                                output_dir,
                                "model_state_mp_{:0>2d}_sharding_{:0>2d}.pdopt".
                                format(mp_rank, sharding_rank)))

                if global_step >= args.max_steps:
                    run_evaluate(args, test_data_loader, model, criterion,
                                 args.test_iters, log_writer, global_step,
                                 epoch, "test")
                    logger.info("The training process is complete.")
                    del train_data_loader
                    return

                # Eval/save time above is excluded from the next reader cost.
                reader_start = time.time()

            del train_data_loader
def do_train(args):
    """Run data-parallel ERNIE pre-training configured by ``args``.

    If ``<output_dir>/model_last/config.yml`` exists, the step counter is
    restored from it. Optimizer and model state are reloaded from that
    directory when ``model_state.pdopt`` is present.

    Training loops over the training loader until ``global_step`` reaches
    ``args.max_steps``, then runs the test evaluation and returns. Rank 0
    writes logs and checkpoints; all ranks meet at a barrier after each save
    when ``worker_num > 1``.
    """
    paddle.set_device(args.device)

    worker_index = paddle.distributed.get_rank()
    worker_num = paddle.distributed.get_world_size()
    local_rank = int(os.getenv("PADDLE_RANK_IN_NODE", 0))

    if worker_num > 1:
        paddle.distributed.init_parallel_env()

    # Default to pure data parallelism across all workers.
    if args.dp_degree * args.sharding_degree == 1:
        args.dp_degree = worker_num
        args.sharding_degree = 1

    args_post_process(args, worker_num)

    logger.info('{:20}:{}'.format("paddle commit id", paddle.version.commit))
    for arg in vars(args):
        logger.info('{:20}:{}'.format(arg, getattr(args, arg)))

    strategy = fleet.DistributedStrategy()
    strategy.hybrid_configs = {
        "dp_degree": args.dp_degree,
        "mp_degree": 1,
        "pp_degree": 1,
        "sharding_degree": 1
    }

    fleet.init(is_collective=True, strategy=strategy)
    hcg = fleet.get_hybrid_communicate_group()

    # Create the random seed for the worker
    set_seed(args)

    assert args.dp_degree * args.sharding_degree == worker_num, \
        "The product of degree num should be equal to worker_num."

    # Only rank 0 creates a log writer.
    log_writer = None
    if worker_index == 0:
        log_writer = LogWriter(os.path.join(args.output_dir, default_logdir()))

    # Look up the model / criterion / tokenizer classes for this model type.
    base_class, model_class, criterion_class, tokenizer_class = MODEL_CLASSES[
        args.model_type]

    pretrained_models_list = list(
        model_class.pretrained_init_configuration.keys())

    # load config in checkpoint
    global_step = 0
    consumed_samples = 0
    checkpoint_dir = os.path.join(args.output_dir, "model_last")
    if os.path.exists(checkpoint_dir):
        if os.path.isfile(os.path.join(checkpoint_dir, "./config.yml")):
            with open(os.path.join(checkpoint_dir, "./config.yml"), "r") as f:
                step_config = yaml.load(f, Loader=yaml.FullLoader)
                assert step_config[
                    "global_batch_size"] == args.global_batch_size, "Please ensure checkpoint global batch size is the same. Folder: {}".format(
                        checkpoint_dir)
                consumed_samples = step_config["consumed_samples"]
                global_step = step_config["global_step"]

    if args.model_name_or_path in pretrained_models_list:
        model_config = model_class.pretrained_init_configuration[
            args.model_name_or_path]
        model_config["hidden_dropout_prob"] = args.hidden_dropout_prob
        model_config[
            "attention_probs_dropout_prob"] = args.attention_probs_dropout_prob
        model = model_class(base_class(**model_config))
    else:
        model = model_class.from_pretrained(
            args.model_name_or_path,
            hidden_dropout_prob=args.hidden_dropout_prob,
            attention_probs_dropout_prob=args.attention_probs_dropout_prob)

    criterion = criterion_class()

    if worker_index == 0:
        # log the model config and args
        model_config_json = json.dumps(model.get_model_config(),
                                       ensure_ascii=False,
                                       indent=2)
        log_writer.add_text("model_config", model_config_json)
        args_dict = {"paddle commit id": str(paddle.version.commit)}
        for arg in vars(args):
            args_dict[arg] = str(getattr(args, arg))
        log_writer.add_text("args", json.dumps(args_dict, indent=2))

    # Create the learning_rate scheduler and optimizer
    if args.decay_steps is None:
        args.decay_steps = args.max_steps
    assert args.warmup_rate <= 1.0 and args.warmup_rate >= 0.0, "warmup_rate should be in [0, 1]"
    args.warmup_steps = args.warmup_rate * args.max_steps

    # last_epoch=global_step resumes the schedule from the restored step.
    lr_scheduler = LinearAnnealingWithWarmupDecay(
        args.max_lr,
        args.min_lr,
        warmup_step=args.warmup_steps,
        decay_step=args.decay_steps,
        last_epoch=global_step)

    clip = None
    if args.grad_clip > 0:
        clip = paddle.fluid.clip.GradientClipByGlobalNorm(
            clip_norm=args.grad_clip)

    # Bias and norm parameters are excluded from weight decay.
    decay_param = [
        p.name for n, p in model.named_parameters()
        if not any(nd in n for nd in ["bias", "norm"])
    ]
    logger.info("Using paddle.optimizer.AdamW.")
    optimizer = paddle.optimizer.AdamW(
        learning_rate=lr_scheduler if lr_scheduler is not None else args.max_lr,
        beta1=args.adam_beta1,
        beta2=args.adam_beta2,
        epsilon=args.adam_epsilon,
        parameters=model.parameters(),
        weight_decay=args.weight_decay,
        grad_clip=clip,
        apply_decay_param_fun=lambda x: x in decay_param,
        multi_precision=args.use_amp)

    if args.use_amp:
        scaler = paddle.amp.GradScaler(init_loss_scaling=args.scale_loss)
        scaler = fleet.distributed_scaler(scaler)
        model = paddle.amp.decorate(models=model,
                                    level='O2',
                                    save_dtype='float32')

    if paddle.distributed.get_world_size() > 1:
        model = fleet.distributed_model(model)
        optimizer = fleet.distributed_optimizer(optimizer)

    tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)

    data_file = get_train_data_file(args)

    train_data_loader, valid_data_loader, test_data_loader = create_pretrained_dataset(
        args,
        data_file,
        tokenizer,
        data_world_size=worker_num,
        data_world_rank=worker_index,
        max_seq_len=args.max_seq_len,
        current_step=global_step)

    # load checkpoint vars
    if os.path.exists(checkpoint_dir):
        if os.path.isfile(os.path.join(checkpoint_dir, "./config.yml")):
            logger.info("Try to load checkpoint from %s " % checkpoint_dir)
            opt_path = os.path.join(checkpoint_dir, "model_state.pdopt")
            params_path = os.path.join(checkpoint_dir, "model_state.pdparams")

            # NOTE(review): model parameters are only restored when the
            # optimizer file also exists; confirm this is intended.
            if os.path.exists(opt_path):
                opt_dict = paddle.load(opt_path)
                optimizer.set_state_dict(opt_dict)
                model_dict = paddle.load(params_path)
                model.set_state_dict(model_dict)
            else:
                logger.warning("No optimizer checkpoint file found in %s." %
                               opt_path)
            logger.info(
                "Checkpoint loaded from global step: {}".format(global_step))

    def save_ckpt(output_dir, model, tokenizer, args, global_step):
        """Save model, tokenizer, optimizer state and a ``config.yml`` step record."""
        step_config = {
            "model_name": args.model_name_or_path,
            "global_step": global_step,
            "global_batch_size": args.global_batch_size,
            "consumed_samples": global_step * args.global_batch_size,
        }

        logger.debug("saving models to {}".format(output_dir))
        model_to_save = model._layers if isinstance(
            model, paddle.DataParallel) else model
        model_to_save.save_pretrained(output_dir)
        tokenizer.save_pretrained(output_dir)
        # `optimizer` is captured from the enclosing do_train scope.
        paddle.save(optimizer.state_dict(),
                    os.path.join(output_dir, "model_state.pdopt"))

        with open(os.path.join(output_dir, "config.yml"), "w") as f:
            yaml.dump(step_config, f, encoding='utf-8', allow_unicode=True)

    # Running sums of the losses between two logging points.
    loss_global = {
        "loss": paddle.to_tensor(0.0),
        "lm_loss": paddle.to_tensor(0.0),
        "sop_loss": paddle.to_tensor(0.0),
    }

    tic_train = time.time()
    # Create the eval iterators exactly once. Calling the loaders inside the
    # `while True` loop would, on a second pass, call the iterator produced by
    # the first pass (not callable). Calling them repeatedly inside enumerate
    # would also start a new random dataloader each time.
    valid_data_loader = valid_data_loader()
    test_data_loader = test_data_loader()
    while True:
        # time count
        train_reader_cost = 0.0
        train_run_cost = 0.0
        reader_start = time.time()
        for step, batch in enumerate(train_data_loader()):
            train_reader_cost += time.time() - reader_start
            train_start = time.time()

            # Batch layout (by position): input_ids, segment_ids, input_mask,
            # masked_lm_positions, masked_lm_labels, next_sentence_labels.
            (input_ids, segment_ids, input_mask, masked_lm_positions,
             masked_lm_labels, next_sentence_labels) = batch

            with paddle.amp.auto_cast(args.use_amp,
                                      custom_black_list=[
                                          "reduce_sum",
                                          "c_softmax_with_cross_entropy",
                                          "elementwise_div"
                                      ],
                                      level='O2'):

                # Create the model for the ernie pretrain
                prediction_scores, seq_relationship_score = model(
                    input_ids=input_ids,
                    token_type_ids=segment_ids,
                    position_ids=None,
                    attention_mask=input_mask,
                    masked_positions=masked_lm_positions)

                lm_loss, sop_loss = criterion(prediction_scores,
                                              seq_relationship_score,
                                              masked_lm_labels,
                                              next_sentence_labels)
                loss = lm_loss + sop_loss

            # NOTE(review): the optimizer steps on every micro-batch; the
            # accumulate_steps check below only gates global_step, logging,
            # the LR schedule and saving. Confirm no gradient accumulation
            # is intended here.
            if args.use_amp:
                scaler.scale(loss).backward()
                scaler.minimize(optimizer, loss)
            else:
                loss.backward()
                optimizer.step()

            optimizer.clear_grad()
            train_run_cost += time.time() - train_start

            # Skip for accumulate_steps in global step
            if (step + 1) % args.accumulate_steps != 0:
                continue

            global_step += 1

            loss_global["loss"] += loss.detach()
            loss_global["lm_loss"] += lm_loss.detach()
            loss_global["sop_loss"] += sop_loss.detach()

            if global_step % args.logging_freq == 0:
                log_info_dict = dict()
                log_info_dict["global_step"] = global_step
                for k, v in loss_global.items():
                    log_info_dict[k] = all_gather(v) / args.logging_freq
                    # Reset the running sum in place.
                    v.subtract_(v)
                if worker_index == 0:
                    speed = args.logging_freq / (time.time() - tic_train)
                    log_info_dict["learning_rate"] = lr_scheduler.get_lr()
                    log_info_dict["steps_per_second"] = speed
                    log_info_dict[
                        "samples_per_second"] = speed * args.global_batch_size

                    for k, v in log_info_dict.items():
                        log_writer.add_scalar("train/%s" % k, v, global_step)

                    common_loginfo = "global step %d, loss: %.9f, lm_loss: %.6f, sop_loss: %.6f, speed: %.2f steps/s, ips: %.2f seqs/s, learning rate: %.5e" % (
                        global_step, log_info_dict["loss"],
                        log_info_dict["lm_loss"], log_info_dict["sop_loss"],
                        speed, log_info_dict["samples_per_second"],
                        log_info_dict["learning_rate"])

                    addition_info = ""
                    if args.use_amp:
                        amp_info = {
                            "loss_scaling": scaler._scale.item(),
                            "incr_count": scaler._incr_count,
                            "decr_count": scaler._decr_count
                        }
                        addition_info = ", ".join("%s: %d" % (k, v)
                                                  for k, v in amp_info.items())
                        addition_info = " " + addition_info
                        for k, v in amp_info.items():
                            log_writer.add_scalar("amp/%s" % k, v,
                                                  global_step)

                    logger.info(common_loginfo + addition_info)

                tic_train = time.time()

            if lr_scheduler is not None:
                lr_scheduler.step()

            if global_step % args.eval_freq == 0:
                # TODO, check the input data of validation
                run_evaluate(valid_data_loader,
                             model,
                             criterion,
                             args.eval_iters,
                             log_writer,
                             global_step,
                             args,
                             task_name="valid")
                tic_train = time.time()

            if global_step % args.save_steps == 0 or global_step >= args.max_steps:
                output_dir = os.path.join(args.output_dir,
                                          "model_%d" % global_step)
                if worker_index == 0:
                    save_ckpt(output_dir, model, tokenizer, args, global_step)

                if worker_num > 1:
                    paddle.distributed.barrier()
                tic_train = time.time()

            if global_step % args.checkpoint_steps == 0:
                output_dir = os.path.join(args.output_dir, "model_last")
                if worker_index == 0:
                    if not os.path.exists(output_dir):
                        os.mkdir(output_dir)
                    # Rotate the previous "model_last" into "model_last_bak"
                    # before writing the new checkpoint.
                    output_dir_bak = os.path.join(args.output_dir,
                                                  "model_last_bak")
                    if os.path.exists(output_dir):
                        if os.path.exists(output_dir_bak):
                            shutil.rmtree(output_dir_bak)
                        shutil.move(output_dir, output_dir_bak)
                        os.mkdir(output_dir)
                    save_ckpt(output_dir, model, tokenizer, args, global_step)

                if worker_num > 1:
                    paddle.distributed.barrier()

            if global_step >= args.max_steps:
                run_evaluate(test_data_loader,
                             model,
                             criterion,
                             args.test_iters,
                             log_writer,
                             global_step,
                             args,
                             task_name="test")
                del train_data_loader
                return
def do_train(args):
    """Run hybrid-parallel (dp/mp/pp) GPT pre-training with optional AMP.

    Builds the GPT model (the pipeline variant when ``args.pp_degree > 1``),
    an AdamW optimizer and, when ``args.use_amp`` is set, a distributed
    ``GradScaler``. Training iterates every file in ``args.input_dir`` whose
    name does not contain ``"npz_"``, for ``args.num_train_epochs`` epochs.

    It logs every ``args.logging_freq`` steps and evaluates every
    ``args.eval_freq`` steps. It saves on ``dp_rank == 0`` every
    ``args.save_steps`` steps and at ``args.max_steps``, then returns once
    ``global_step`` reaches ``args.max_steps``. With ``args.check_accuracy``,
    evaluation and saving are skipped.
    """
    paddle.set_device(args.device)
    strategy = fleet.DistributedStrategy()
    strategy.hybrid_configs = {
        "dp_degree": args.dp_degree,
        "mp_degree": args.mp_degree,
        "pp_degree": args.pp_degree
    }

    strategy.pipeline_configs = {
        "accumulate_steps": args.local_batch_size // args.micro_batch_size,
        "micro_batch_size": args.micro_batch_size
    }

    fleet.init(is_collective=True, strategy=strategy)

    # obtain rank message of hybrid parallel
    hcg = fleet.get_hybrid_communicate_group()
    global_rank = hcg.get_global_rank()
    mp_rank = hcg.get_model_parallel_rank()
    pp_rank = hcg.get_stage_id()
    dp_rank = hcg.get_data_parallel_rank()
    local_rank = int(os.getenv("PADDLE_RANK_IN_NODE", 0))

    # seed control in hybrid parallel
    set_hyrbid_parallel_seed(args.seed, dp_rank, mp_rank, pp_rank)

    default_global_tokens_num = args.global_batch_size * args.max_seq_len

    model_class, tokenizer_class = MODEL_CLASSES[args.model_type]

    tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)

    # Define log writer (one directory per global rank, recreated each run).
    log_writer_path = os.path.join(
        args.output_dir, "train_log",
        "{}_globalbsz_{}_amp_{}_recompute_{}_card_{}".format(
            args.model_name_or_path, args.global_batch_size, args.use_amp,
            False, global_rank).lower())

    if os.path.exists(log_writer_path):
        import shutil
        shutil.rmtree(log_writer_path)

    log_writer = LogWriter(log_writer_path)

    pretrained_models_list = list(
        model_class.pretrained_init_configuration.keys())

    if args.model_name_or_path in pretrained_models_list:
        model_config = model_class.pretrained_init_configuration[
            args.model_name_or_path]
        model_config["hidden_dropout_prob"] = args.hidden_dropout_prob
        model_config[
            "attention_probs_dropout_prob"] = args.attention_probs_dropout_prob
        model_config['num_partitions'] = args.mp_degree

        if args.pp_degree == 1:
            model = GPTForPretraining(GPTModel(**model_config))
        else:
            model_config['topology'] = hcg.topology()
            model_config["recompute_interval"] = 1 if args.use_recompute else 0
            model = GPTForPretrainingPipe(**model_config)
    else:
        model = GPTForPretraining.from_pretrained(
            args.model_name_or_path,
            hidden_dropout_prob=args.hidden_dropout_prob,
            attention_probs_dropout_prob=args.attention_probs_dropout_prob)

    # Create the criterion for the gpt model
    criterion = GPTPretrainingCriterion()

    if args.decay_steps is None:
        args.decay_steps = args.max_steps
    warmup_step = args.warmup_rate * args.decay_steps

    lr_scheduler = None

    if args.lr_decay_style == "none":
        lr_scheduler = None
    elif args.lr_decay_style == "cosine":
        lr_scheduler = lr.CosineAnnealingWithWarmupDecay(
            max_lr=args.max_lr,
            min_lr=args.min_lr,
            warmup_step=warmup_step,
            decay_step=args.decay_steps)

    clip = None
    if args.grad_clip > 0:
        clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=args.grad_clip)

    # Generate parameter names needed to perform weight decay.
    # All bias and LayerNorm parameters are excluded.
    decay_params = [
        p.name for n, p in model.named_parameters()
        if not any(nd in n for nd in ["bias", "norm"])
    ]

    optimizer = paddle.optimizer.AdamW(
        learning_rate=lr_scheduler if lr_scheduler is not None else args.max_lr,
        beta1=args.adam_beta1,
        beta2=args.adam_beta2,
        epsilon=args.adam_epsilon,
        parameters=model.parameters(),
        weight_decay=args.weight_decay,
        grad_clip=clip,
        apply_decay_param_fun=lambda x: x in decay_params)

    if paddle.distributed.get_world_size() > 1:
        model = fleet.distributed_model(model)
        optimizer = fleet.distributed_optimizer(optimizer)

    if args.use_amp:
        scaler = paddle.amp.GradScaler(init_loss_scaling=args.scale_loss)
        scaler = fleet.distributed_scaler(scaler)

    if args.model_name_or_path not in pretrained_models_list:
        logger.info("Try to load checkpoint from %s " % args.model_name_or_path)
        opt_path = os.path.join(args.model_name_or_path, "model_state.pdopt")
        if os.path.exists(opt_path):
            opt_dict = paddle.load(opt_path)
            optimizer.set_state_dict(opt_dict)
        else:
            logger.warning("No optimizer checkpoint file found in %s." %
                           opt_path)

    global_step = 0
    tic_train = time.time()
    for epoch in range(args.num_train_epochs):
        files = [
            os.path.join(args.input_dir, f)
            for f in os.listdir(args.input_dir)
            if (os.path.isfile(os.path.join(args.input_dir, f)) and
                "npz_" not in str(f))
        ]
        files.sort()
        num_files = len(files)
        for f_id in range(num_files):
            data_file = files[f_id]
            train_data_loader, valid_data_loader, test_data_loader = create_pretrained_dataset(
                args,
                data_file,
                local_rank=local_rank,
                data_world_size=args.dp_degree,
                data_world_rank=dp_rank,
                eos_id=tokenizer.eos_token_id)
            # Bug fix, if not call valid_data_loader, the enumerate will call valid_data_loader
            # many times. and start a new random dataloader.
            valid_data_loader = valid_data_loader()
            test_data_loader = test_data_loader()

            for step, batch in enumerate(train_data_loader()):
                global_step += 1
                tokens, loss_mask, labels = batch

                loss_mask.stop_gradient = True
                labels.stop_gradient = True

                if args.pp_degree == 1:
                    with paddle.amp.auto_cast(
                            args.use_amp,
                            custom_white_list=[
                                "layer_norm", "softmax", "gelu"
                            ],
                            custom_black_list=[
                                "reduce_sum", "c_softmax_with_cross_entropy",
                                "c_embedding"
                            ]):
                        preds = model(tokens)
                        loss = criterion(preds, labels, loss_mask)

                    if args.use_amp:
                        scaler.scale(loss).backward()
                        scaler.minimize(optimizer, loss)
                    else:
                        loss.backward()
                        optimizer.step()

                    if lr_scheduler is not None:
                        lr_scheduler.step()

                    optimizer.clear_grad()
                else:
                    data = [tokens, (labels, loss_mask)]
                    with paddle.amp.auto_cast(
                            args.use_amp,
                            custom_white_list=[
                                "layer_norm", "softmax", "gelu"
                            ],
                            custom_black_list=[
                                "reduce_sum", "c_softmax_with_cross_entropy",
                                "c_embedding"
                            ]):
                        loss = model.train_batch(
                            data,
                            optimizer=optimizer,
                            lr_scheduler=lr_scheduler,
                            scaler=scaler if args.use_amp else None)

                if global_step % args.logging_freq == 0:
                    avg_loss = loss.numpy()
                    speed = args.logging_freq / (time.time() - tic_train)
                    logger.info(
                        "global step %d, epoch: %d, batch: %d, loss: %.9f, speed: %.2f step/s, ips: %.0f tokens/s, learning rate: %.5e"
                        % (global_step, epoch, step, avg_loss, speed,
                           speed * default_global_tokens_num,
                           optimizer.get_lr()))
                    log_writer.add_scalar("loss", float(loss), global_step)
                    log_writer.add_scalar("learning_rate",
                                          optimizer.get_lr(), global_step)

                    tic_train = time.time()

                if args.check_accuracy:
                    if global_step >= args.max_steps:
                        return
                    else:
                        continue

                if global_step % args.eval_freq == 0:
                    # Since the valid data broardcast to all devices, we do evaluate on all device.
                    run_evaluate(args, valid_data_loader, model, criterion,
                                 args.eval_iters, log_writer, global_step,
                                 epoch, "valid")

                # only dp_rank = 0 save model
                if (global_step % args.save_steps == 0 or
                        global_step >= args.max_steps) and dp_rank == 0:

                    model_to_save = model._layers if paddle.distributed.get_world_size(
                    ) > 1 else model
                    output_dir = os.path.join(args.output_dir,
                                              "step_%d" % global_step)
                    os.makedirs(output_dir, exist_ok=True)

                    logger.info("Save model to %s" % output_dir)

                    if args.pp_degree > 1:
                        # Every mp/pp rank writes its own model and optimizer shard.
                        model_to_save.save_state_dict(output_dir)
                        # Only the first mp/pp rank writes the shared tokenizer.
                        # The previous `mp_rank * pp_rank == 1` check matched
                        # only rank (1, 1), so it never fired when mp_degree == 1.
                        if mp_rank == 0 and pp_rank == 0:
                            tokenizer.save_pretrained(output_dir)
                        paddle.save(
                            optimizer.state_dict(),
                            os.path.join(
                                output_dir,
                                "model_state_mp_{:0>2d}_pp_{:0>2d}.pdopt".
                                format(mp_rank, pp_rank)))
                    else:
                        path = os.path.join(output_dir,
                                            'model_{:0>2d}'.format(mp_rank))
                        os.makedirs(path, exist_ok=True)
                        model_to_save.save_pretrained(path)

                        paddle.save(optimizer.state_dict(),
                                    os.path.join(path, "model_state.pdopt"))
                        tokenizer.save_pretrained(path)

                if global_step >= args.max_steps:
                    run_evaluate(args, test_data_loader, model, criterion,
                                 args.test_iters, log_writer, global_step,
                                 epoch, "test")
                    logger.info("The training process is complete.")
                    del train_data_loader
                    return

            del train_data_loader
def do_train(args):
    """Run hybrid-parallel (dp/mp/pp) MoE GPT pretraining configured by ``args``.

    The function:
      * initializes fleet with the requested dp/mp/pp degrees and per-rank seeds,
      * builds the GPT(-MoE) model, AdamW optimizer and optional cosine LR schedule,
      * optionally enables pure-fp16 (AMP O2) training with a distributed GradScaler,
      * optionally resumes model/optimizer/reader state from ``args.resume_dir``,
      * iterates over the data files in ``args.input_dir`` with gradient
        accumulation over ``accumulate_steps`` micro-batches, logging,
        checkpointing and evaluating periodically,
      * returns (None) once ``global_step`` reaches ``args.max_steps`` or all
        epochs are consumed.

    Args:
        args: parsed command-line namespace holding all training options.
    """
    paddle.set_device(args.device)
    strategy = fleet.DistributedStrategy()
    strategy.hybrid_configs = {
        "dp_degree": args.dp_degree,
        "mp_degree": args.mp_degree,
        "pp_degree": args.pp_degree
    }

    accumulate_steps = args.local_batch_size // args.micro_batch_size
    strategy.pipeline_configs = {
        "accumulate_steps": accumulate_steps,
        "micro_batch_size": args.micro_batch_size
    }

    fleet.init(is_collective=True, strategy=strategy)

    # Obtain the rank information of the hybrid-parallel topology.
    hcg = fleet.get_hybrid_communicate_group()
    global_rank = hcg.get_global_rank()
    mp_rank = hcg.get_model_parallel_rank()
    pp_rank = hcg.get_stage_id()
    dp_rank = hcg.get_data_parallel_rank()
    local_rank = int(os.getenv("PADDLE_RANK_IN_NODE", 0))

    # Seed control in hybrid parallel.
    set_hyrbid_parallel_seed(args.seed, dp_rank, mp_rank, pp_rank)

    default_global_tokens_num = args.global_batch_size * args.max_seq_len

    model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
    tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)

    # Define the log writer. An existing directory with the same name is
    # removed first.
    # NOTE(review): the "recompute_{}" field is always filled with False,
    # not args.use_recompute — confirm whether that is intended.
    log_writer_path = os.path.join(
        args.output_dir, "train_log",
        "{}_globalbsz_{}_pure_fp16_{}_recompute_{}_card_{}".format(
            args.model_name_or_path, args.global_batch_size,
            args.use_pure_fp16, False, global_rank).lower())

    if os.path.exists(log_writer_path):
        import shutil
        shutil.rmtree(log_writer_path)

    log_writer = LogWriter(log_writer_path)

    pretrained_models_list = list(
        model_class.pretrained_init_configuration.keys())

    if args.model_name_or_path in pretrained_models_list:
        model_config = model_class.pretrained_init_configuration[
            args.model_name_or_path]
        model_config["hidden_dropout_prob"] = args.hidden_dropout_prob
        model_config[
            "attention_probs_dropout_prob"] = args.attention_probs_dropout_prob

        model_config['num_partitions'] = args.mp_degree

        # MoE config
        initialize_model_and_expert_group(hcg)

        model_config['expert_mode'] = args.expert_mode
        model_config['hcg'] = hcg
        model_config['num_experts'] = args.num_experts
        model_config['top_k'] = args.top_k
        if args.expert_mode:
            model_config['gate'] = args.gate

        if args.pp_degree == 1:
            model_config["recompute_interval"] = 1 if args.use_recompute else 0
            model_config["recompute_partition"] = args.recompute_partition
            model_config["recompute_offload"] = args.recompute_offload
            if args.use_recompute and args.recompute_partition:
                raise Exception(
                    "when use_recompute is True, recompute_partition must be False in MoE."
                )
            model = GPTForPretraining(GPTModel(**model_config))
        else:
            model_config['topology'] = hcg.topology()
            model_config["recompute_interval"] = 1 if args.use_recompute else 0
            model = GPTForPretrainingPipe(**model_config)
    else:
        model = GPTForPretraining.from_pretrained(
            args.model_name_or_path,
            hidden_dropout_prob=args.hidden_dropout_prob,
            attention_probs_dropout_prob=args.attention_probs_dropout_prob)

    # Create the criterion for the gpt model.
    criterion = GPTPretrainingCriterion()

    if args.decay_steps is None:
        args.decay_steps = args.max_steps
    warmup_step = args.warmup_rate * args.decay_steps
    lr_scheduler = None

    if args.lr_decay_style == "none":
        lr_scheduler = None
    elif args.lr_decay_style == "cosine":
        lr_scheduler = lr.CosineAnnealingWithWarmupDecay(
            max_lr=args.max_lr,
            min_lr=args.min_lr,
            warmup_step=warmup_step,
            decay_step=args.decay_steps)

    if args.use_pure_fp16:
        scaler = paddle.amp.GradScaler(init_loss_scaling=args.scale_loss)
        scaler = fleet.distributed_scaler(scaler)
        scaler._unscale = MethodType(unscale_method, scaler)
        # Decorate the model before classifying parameters so that the fused
        # tensors below are built from the decorated model.
        model = paddle.amp.decorate(models=model,
                                    optimizers=None,
                                    level='O2',
                                    save_dtype='float32')

    # parameters_classify splits the parameters into the groups used below:
    # optimizer params, weight-decay params, dp-all-reduce params, expert-group
    # all-reduce params, and expert parameter names.
    # (Presumably bias/LayerNorm are excluded from decay — confirm in
    # parameters_classify.)
    opt_fused_tensors, decay_fused_tensors, reduce_fused_tensors, gate_fused_tensors, \
        expert_fusion_names = parameters_classify(model)
    decay_params = [p.name for p in decay_fused_tensors]

    clip = None
    if args.grad_clip > 0:
        is_expert_param_fun = lambda param: param.name in expert_fusion_names
        clip = moe.ClipGradByGlobalNorm(
            clip_norm=args.grad_clip,
            is_expert_param_func=is_expert_param_fun,
            moe_group=hcg.get_expert_parallel_group())

    optimizer = AdamW(
        learning_rate=lr_scheduler
        if lr_scheduler is not None else args.max_lr,
        beta1=args.adam_beta1,
        beta2=args.adam_beta2,
        epsilon=args.adam_epsilon,
        parameters=opt_fused_tensors,
        weight_decay=args.weight_decay,
        grad_clip=clip,
        apply_decay_param_fun=lambda x: x in decay_params,
        multi_precision=args.use_pure_fp16)

    # Normalize resume_dir BEFORE it is used to decide whether parameters must
    # be initialized: an empty value (or None) means "no resume". Previously
    # this ran after the check below, so with an empty-string default the
    # initialization was skipped on fresh runs, and with a None default
    # len(None) raised TypeError.
    if not args.resume_dir:
        args.resume_dir = None

    if paddle.distributed.get_world_size() > 1 and args.resume_dir is None:
        print(">> initialize....")
        initialize_mp_dp_parameters(model, hcg)

    # Reader-restoration state: pass_num is the total number of batches in the
    # data files already fully consumed, file_id is the file index to start
    # from, start_epoch is the epoch to start from.
    pass_num = 0
    file_id = 0
    start_epoch = 0

    if args.resume_dir is not None:
        global_step, loss_scale, data_meta = load_checkpoint(
            args, model, optimizer, lr_scheduler, tokenizer, dp_rank, mp_rank,
            pp_rank)
        pass_num = data_meta["pass_num"]
        file_id = data_meta["file_id"]
        start_epoch = data_meta["start_epoch"]

    if args.use_pure_fp16:
        # Re-create the scaler so a resumed run continues from the saved
        # loss scale.
        scaler = paddle.amp.GradScaler(
            init_loss_scaling=loss_scale
            if args.resume_dir is not None else args.scale_loss)
        scaler = fleet.distributed_scaler(scaler)
        scaler._unscale = MethodType(unscale_method, scaler)
        model, optimizer = paddle.amp.decorate(models=model,
                                               optimizers=optimizer,
                                               level='O2',
                                               save_dtype='float32')

    if args.model_name_or_path not in pretrained_models_list:
        logger.info("Try to load checkpoint from %s " % args.model_name_or_path)
        opt_path = os.path.join(args.model_name_or_path, "model_state.pdopt")
        if os.path.exists(opt_path):
            opt_dict = paddle.load(opt_path)
            optimizer.set_state_dict(opt_dict)
        else:
            logger.warning("No optimizer checkpoint file found in %s." %
                           opt_path)

    global_step = 0 if args.resume_dir is None else global_step
    timers = get_timers()
    tic_train = time.time()
    for epoch in range(start_epoch, args.num_train_epochs):
        files = [
            os.path.join(args.input_dir, f)
            for f in os.listdir(args.input_dir)
            if (os.path.isfile(os.path.join(args.input_dir, f))
                and "npz_" not in str(f))
        ]
        files.sort()
        num_files = len(files)
        for f_id in range(file_id, num_files):
            data_file = files[f_id]
            train_data_loader, valid_data_loader, test_data_loader = create_pretrained_dataset(
                args,
                data_file,
                local_rank=local_rank,
                data_world_size=args.dp_degree,
                data_world_rank=dp_rank,
                eos_id=tokenizer.eos_token_id)
            # Bug fix: call the valid/test loaders once here; otherwise each
            # enumerate would call them again and start a new random
            # dataloader every time.
            valid_data_loader = valid_data_loader()
            test_data_loader = test_data_loader()

            for step, batch in enumerate(train_data_loader()):
                # Skip the training batches already consumed before resuming.
                if step < global_step - pass_num:
                    continue
                global_step += 1
                tokens, loss_mask, labels = batch

                loss_mask.stop_gradient = True
                labels.stop_gradient = True

                # Gradient accumulation over micro-batches of the local batch.
                loss = 0.0
                for i in range(accumulate_steps):
                    start_index = i * args.micro_batch_size
                    end_index = start_index + args.micro_batch_size
                    timers('forward-compute').start()
                    with paddle.amp.auto_cast(
                            args.use_pure_fp16,
                            custom_black_list=[
                                "reduce_sum",
                                "c_softmax_with_cross_entropy",
                                "elementwise_div",
                            ],
                            level='O2'):
                        preds = model(tokens[start_index:end_index, :])
                        loss_mbs = criterion(
                            preds, labels[start_index:end_index, :],
                            loss_mask[start_index:end_index, :])
                    timers('forward-compute').stop()

                    # Add the MoE gate balance (auxiliary) loss when enabled.
                    if args.gate != "naive" and args.balance_loss_weight:
                        aux_loss_list = [
                            layer.moe_mlp.gate.get_loss(clear=False)
                            for layer in model.gpt.decoder.layers
                            if hasattr(layer.moe_mlp, "gate")
                        ]
                        bal_loss = paddle.concat(aux_loss_list)
                        if bal_loss.dtype == paddle.float16:
                            bal_loss = paddle.cast(bal_loss,
                                                   dtype=paddle.float32)
                        bal_loss = bal_loss.mean()
                        loss_mbs += bal_loss * args.balance_loss_weight
                    loss_mbs = loss_mbs / accumulate_steps

                    timers('backward-compute').start()
                    if args.use_pure_fp16:
                        scaler.scale(loss_mbs).backward()
                    else:
                        loss_mbs.backward()
                    timers('backward-compute').stop()
                    loss = loss + loss_mbs

                timers('backward-params-all-reduce').start()
                all_reduce_parameters(gate_fused_tensors,
                                      hcg.get_expert_parallel_group())
                all_reduce_parameters(reduce_fused_tensors,
                                      hcg.get_data_parallel_group())
                timers('backward-params-all-reduce').stop()

                if args.use_pure_fp16:
                    scaler.minimize(optimizer, loss)
                else:
                    optimizer.step()
                # Read the LR used for this step before the scheduler advances.
                learning_rate = optimizer.get_lr()
                if lr_scheduler is not None:
                    lr_scheduler.step()
                optimizer.clear_grad()

                if global_step % args.logging_freq == 0:
                    avg_loss = loss.numpy()
                    speed = args.logging_freq / (time.time() - tic_train)
                    if args.gate != "naive" and args.balance_loss_weight:
                        # Report the LM loss separately from the last
                        # micro-batch's balance loss.
                        bal_loss = bal_loss.numpy()
                        avg_loss -= bal_loss
                    else:
                        bal_loss = -1
                    logger.info(
                        "global step %d, epoch: %d, batch: %d, loss: %.9f, bal_loss: %.9f, speed: %.2f step/s, ips: %.0f tokens/s, learning rate: %.5e"
                        % (global_step, epoch, step, avg_loss, bal_loss,
                           speed, speed * default_global_tokens_num,
                           learning_rate))
                    log_writer.add_scalar("loss", float(loss), global_step)
                    log_writer.add_scalar("learning_rate", learning_rate,
                                          global_step)

                    tic_train = time.time()
                    timer_log(args.logging_freq)

                if (global_step % args.save_steps == 0
                        or global_step >= args.max_steps):
                    loss_scale = scaler._scale if args.use_pure_fp16 else None
                    # Record the file currently being read (f_id), not the
                    # stale starting index, so a resume restarts at the right
                    # file together with pass_num.
                    save_checkpoint(args, global_step, model, optimizer,
                                    lr_scheduler, tokenizer, loss_scale,
                                    dp_rank, mp_rank, pp_rank, pass_num,
                                    f_id, epoch)
                    print(
                        "save checkpoint for step_{} successfully...loss_scale = {}"
                        .format(global_step, loss_scale))

                if global_step % args.eval_freq == 0:
                    # Since the valid data is broadcast to all devices, we
                    # evaluate on all devices.
                    run_evaluate(args, valid_data_loader, model, criterion,
                                 args.eval_iters, log_writer, global_step,
                                 epoch, "valid")

                if global_step >= args.max_steps:
                    run_evaluate(args, test_data_loader, model, criterion,
                                 args.test_iters, log_writer, global_step,
                                 epoch, "test")
                    logger.info("The training process is complete.")
                    del train_data_loader
                    return

            # Accumulate the number of batches in the files fully read so far.
            pass_num += len(train_data_loader())
            del train_data_loader

        # Only the resumed epoch starts part-way through the file list;
        # every later epoch reads all files from the beginning.
        file_id = 0