def do_train(args):
    """Run static-graph GPT pretraining with fleet hybrid parallelism.

    Builds the static program (data holders, model, criterion, AdamW
    optimizer wrapped by ``fleet.distributed_optimizer``), dumps the program
    descriptions under ``<args.output_dir>/program_desc``, optionally restores
    parameters from ``args.model_name_or_path`` and then trains, logging,
    evaluating and saving at the frequencies given in ``args``.

    Args:
        args: Parsed argument namespace. Fields read here include ``seed``,
            ``device``, ``dp_degree``/``pp_degree``/``sharding_degree``/
            ``mp_degree``, ``output_dir``, ``model_type``,
            ``model_name_or_path``, the optimizer/scheduler settings,
            ``logging_freq``, ``eval_freq``, ``save_steps``, ``max_steps``
            and ``check_accuracy``. ``args.decay_steps`` is overwritten with
            ``args.max_steps`` when it is ``None``.

    Returns:
        None. The function returns once ``global_step`` reaches
        ``args.max_steps``.
    """
    # Initialize the paddle and paddle fleet execute environment
    paddle.enable_static()
    fleet.init(is_collective=True)

    # Create the random seed for the worker. Only 'local_seed' is offset by
    # the worker index; every other seed is identical on all workers.
    random.seed(args.seed)
    np.random.seed(args.seed)
    paddle.seed(args.seed)
    get_rng_state_tracker().add('global_seed', args.seed)
    get_rng_state_tracker().add('local_seed',
                                args.seed + fleet.worker_index() + 2021)

    assert args.device in [
        "cpu", "gpu", "xpu"
    ], "Invalid device! Available device should be cpu, gpu, or xpu."
    place = paddle.set_device(args.device)

    worker_num = fleet.worker_num()
    worker_index = fleet.worker_index()

    topo = Topology(
        device_rank=worker_index,
        world_size=worker_num,
        dp_degree=args.dp_degree,
        pp_degree=args.pp_degree,
        sharding_degree=args.sharding_degree,
        mp_degree=args.mp_degree)

    logger.info("The topo of hybrid parallelism:\n{}".format(topo))

    dist_strategy = dist_optimizer(args, topo)

    # Create log write, train results show on last card of pipeline.
    # The name is always bound (None on other ranks) because it is passed to
    # run_evaluate on every rank; previously non-last ranks raised NameError.
    # NOTE(review): run_evaluate also receives topo.is_last and presumably
    # only writes scalars when it is true -- confirm against its definition.
    log_writer = None
    if topo.is_last:
        log_writer_path = os.path.join(
            args.output_dir, "train_log",
            "{}_globalbsz_{}_amp_{}_recompute_{}_card_{}".format(
                args.model_name_or_path, args.global_batch_size, args.use_amp,
                args.use_recompute, worker_index).lower())
        if os.path.exists(log_writer_path):
            import shutil
            shutil.rmtree(log_writer_path)
        log_writer = LogWriter(log_writer_path)

    # Define the input data in the static mode
    model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
    pretrained_models_list = list(
        model_class.pretrained_init_configuration.keys())

    data_file = get_train_data_file(args)
    main_program = paddle.static.default_main_program()
    startup_program = paddle.static.default_startup_program()
    with paddle.static.program_guard(main_program, startup_program):
        with paddle.utils.unique_name.guard():
            with paddle.static.device_guard('gpu:0'):
                data_holders = create_data_holder(args)
                [tokens, loss_mask, attention_mask, position_ids,
                 labels] = data_holders

                tokenizer = tokenizer_class.from_pretrained(
                    args.model_name_or_path)
                eos_id = tokenizer.eos_token_id

                train_data_loader, valid_data_loader, test_data_loader = create_pretrained_dataset(
                    args,
                    data_file,
                    data_world_size=topo.data_info.size,
                    data_world_rank=topo.data_info.rank,
                    eos_id=eos_id,
                    max_seq_len=args.max_seq_len,
                    places=paddle.static.cuda_places(),
                    data_holders=data_holders,
                    pipeline_mode=False, )

                if args.model_name_or_path in pretrained_models_list:
                    # Built-in configuration name: build the model from the
                    # stored config with the dropout values and topology
                    # taken from this run.
                    model_config = model_class.pretrained_init_configuration[
                        args.model_name_or_path]
                    model_config[
                        "hidden_dropout_prob"] = args.hidden_dropout_prob
                    model_config[
                        "attention_probs_dropout_prob"] = args.attention_probs_dropout_prob
                    model_config["topo"] = topo

                    model = guard(f'gpu:{args.pp_degree -1}')(
                        GPTForPretraining)(
                            guard(f'gpu:0')(GPTModel)(**model_config))
                else:
                    model, _ = GPTForPretraining.from_pretrained(
                        args.model_name_or_path,
                        hidden_dropout_prob=args.hidden_dropout_prob,
                        attention_probs_dropout_prob=args.
                        attention_probs_dropout_prob,
                        topo=topo)

                # Create the model for the gpt pretrain
                preds = model(tokens, position_ids, attention_mask)

                criterion = guard(f'gpu:{args.pp_degree -1}')(
                    GPTPretrainingCriterion)(topo)
                loss = criterion(preds, labels, loss_mask)

            # Create the learning_rate scheduler and optimizer
            if args.decay_steps is None:
                args.decay_steps = args.max_steps
            warmup_step = args.warmup_rate * args.decay_steps

            # TODO @ZHUI Use paddle network to support lr scheduler
            lr_scheduler = lr.CosineAnnealingWithWarmupDecay(
                max_lr=args.max_lr,
                min_lr=args.min_lr,
                warmup_step=warmup_step,
                decay_step=args.decay_steps)

            clip = None
            if args.grad_clip > 0:
                clip = paddle.fluid.clip.GradientClipByGlobalNorm(
                    clip_norm=args.grad_clip)

            # Weight decay is applied only to parameters whose structured
            # name contains neither "bias" nor "norm".
            decay_param = [
                p.name for n, p in model.named_parameters()
                if not any(nd in n for nd in ["bias", "norm"])
            ]

            optimizer = paddle.optimizer.AdamW(
                learning_rate=lr_scheduler,
                beta1=args.adam_beta1,
                beta2=args.adam_beta2,
                epsilon=args.adam_epsilon,
                grad_clip=clip,
                weight_decay=args.weight_decay,
                apply_decay_param_fun=lambda x: x in decay_param)
            # alias
            optimizer.apply_optimize = optimizer._apply_optimize

            if args.use_recompute:
                dist_strategy.recompute = True
                dist_strategy.recompute_configs = {
                    "checkpoints": model.gpt.checkpoints
                }

            # Use the fleet api to compile the distributed optimizer
            optimizer = fleet.distributed_optimizer(
                optimizer, strategy=dist_strategy)

            optimizer.minimize(loss)
            logger.info(f'final strategy: {fleet._final_strategy()}')
            logger.info("The training meta optimizer is/are %s" %
                        fleet._get_applied_meta_list())

    # Dump the compiled programs for inspection. makedirs(exist_ok=True)
    # also creates a missing args.output_dir and does not fail when another
    # worker creates the directory at the same time.
    program_desc_dir = os.path.join(args.output_dir, "program_desc")
    os.makedirs(program_desc_dir, exist_ok=True)

    with open(program_desc_dir + "/main_program.txt.%d" % worker_index,
              'w') as f:
        f.write(str(main_program))

    with open(program_desc_dir + "/startup_program.txt.%d" % worker_index,
              'w') as f:
        f.write(str(startup_program))

    # Define the Executor for running the static model
    exe = paddle.static.Executor(place)
    exe.run(startup_program)
    test_program = main_program.clone(for_test=True)

    if args.model_name_or_path not in pretrained_models_list:
        # model_name_or_path is treated as a checkpoint directory: static
        # variables are tried first, then dygraph parameters.
        logger.info("Try to load checkpoint from %s " %
                    args.model_name_or_path)
        dygraph_path = os.path.join(args.model_name_or_path,
                                    "model_state.pdparams")
        static_path = os.path.join(args.model_name_or_path, "static_vars")

        flag_loaded = False
        if os.path.exists(static_path):
            if args.mp_degree > 1:
                logger.warning("MP should init with dygraph params")
            else:
                logger.info("Loading parameters from %s" % static_path)
                paddle.static.load(main_program, static_path, exe)
                flag_loaded = True

        if not flag_loaded and os.path.exists(dygraph_path):
            if args.sharding_degree > 1:
                logger.warning("Sharding should init with static vars")
            else:
                logger.info("Loading parameters from %s" % dygraph_path)
                init_static_with_params(
                    model,
                    paddle.load(
                        dygraph_path, return_numpy=True),
                    topo,
                    main_program)
                flag_loaded = True

        if not flag_loaded:
            logger.error("No checkpoint load.")

    global_step = 0
    tic_train = time.time()
    epoch = 0
    learning_rate = main_program.global_block().vars["learning_rate_0"]

    # Only the last pipeline stage fetches the loss and learning rate.
    fetchs = [loss, learning_rate] if topo.is_last else []

    # Bug fix, if not call valid_data_loader, the enumerate will call
    # valid_data_loader many times. and start a new random dataloader.
    # The calls are made once, before the epoch loop: rebinding inside the
    # loop would call the already-created iterator on a second epoch.
    valid_data_loader = valid_data_loader()
    test_data_loader = test_data_loader()

    while True:
        for step, batch in enumerate(train_data_loader()):
            global_step += 1
            ret = exe.run(main_program,
                          feed=batch,
                          fetch_list=fetchs,
                          use_program_cache=True)
            # In the new 2.0 api, must call this function to change the learning_rate
            lr_scheduler.step()

            if global_step % args.logging_freq == 0:
                if topo.is_last:
                    loss_return, lr_return = ret
                    speed = args.logging_freq / (time.time() - tic_train)
                    logger.info(
                        "global step %d, epoch: %d, batch: %d, loss: %.9f, speed: %.2f steps/s, ips: %.0f tokens/s, learning rate: %.5e"
                        % (global_step, epoch, step, loss_return[0], speed,
                           speed * args.global_batch_size * args.max_seq_len,
                           lr_return[0]))
                    log_writer.add_scalar("loss", loss_return[0], global_step)
                    log_writer.add_scalar("learning_rate", lr_return[0],
                                          global_step)
                tic_train = time.time()

            # Accuracy-check mode skips evaluation and saving entirely.
            if args.check_accuracy:
                if global_step >= args.max_steps:
                    return
                else:
                    continue

            if global_step % args.eval_freq == 0:
                # TODO, check the input data of validation
                eval_fetch = []
                if topo.is_last:
                    eval_fetch = [loss]

                run_evaluate(valid_data_loader, exe, test_program,
                             args.eval_iters, log_writer, global_step, args,
                             epoch, topo.is_last, eval_fetch, "valid")
                # Restart the timer so evaluation time is not counted as
                # training time in the next speed report.
                tic_train = time.time()

            if global_step % args.save_steps == 0 or global_step >= args.max_steps:
                output_dir = os.path.join(args.output_dir,
                                          "model_%d" % global_step)
                logger.debug("saving models to {}".format(output_dir))
                save_persistables(exe,
                                  os.path.join(output_dir, "static_vars"),
                                  main_program)
                if global_step == args.save_steps:
                    # Drop the "topo" entry added at construction time before
                    # the first save_pretrained call.
                    model.init_config["init_args"][0].init_config.pop(
                        "topo", None)
                model.save_pretrained(output_dir)
                tokenizer.save_pretrained(output_dir)
                tic_train = time.time()

            if global_step >= args.max_steps:
                eval_fetch = []
                if topo.is_last:
                    eval_fetch = [loss]

                run_evaluate(test_data_loader, exe, test_program,
                             args.test_iters, log_writer, global_step, args,
                             epoch, topo.is_last, eval_fetch, "test")
                del train_data_loader
                return
        epoch += 1
def do_train(args):
    """Run dygraph (data-parallel only) GPT pretraining.

    Builds the tokenizer, model, criterion and AdamW optimizer, optionally
    restores the optimizer state from ``args.model_name_or_path``, then
    iterates over the training data files one at a time, training until
    ``args.max_steps`` with periodic logging, validation and saving.

    Args:
        args: Parsed argument namespace. Fields read here include ``device``,
            ``micro_batch_size``, ``max_seq_len``, ``model_type``,
            ``model_name_or_path``, ``output_dir``, ``lr_decay_style``, the
            optimizer settings, ``use_amp``/``scale_loss``,
            ``profiler_options``, ``logging_freq``, ``eval_freq``,
            ``save_steps``, ``max_steps`` and ``check_accuracy``.
            ``args.decay_steps`` is overwritten with ``args.max_steps`` when
            it is ``None``.

    Returns:
        None. The function returns once ``global_step`` reaches
        ``args.max_steps``.

    Raises:
        ValueError: If ``get_train_data_file(args)`` yields no files; without
            this check the outer loop would spin forever doing nothing.
    """
    paddle.set_device(args.device)
    if paddle.distributed.get_world_size() > 1:
        paddle.distributed.init_parallel_env()

    worker_index = paddle.distributed.get_rank()
    worker_num = paddle.distributed.get_world_size()
    local_rank = int(os.getenv("PADDLE_RANK_IN_NODE", 0))
    set_seed(args)
    # Now, we only support data parallel in dygraph mode for now.
    topo = Topology(
        device_rank=worker_index, world_size=worker_num, dp_degree=worker_num)

    default_global_batch_size = topo.data_info.size * args.micro_batch_size
    default_global_tokens_num = default_global_batch_size * args.max_seq_len

    model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
    tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)

    # Define log writer. Any previous log directory with the same name is
    # removed first.
    log_writer_path = os.path.join(
        args.output_dir, "train_log",
        "{}_globalbsz_{}_amp_{}_recompute_{}_card_{}".format(
            args.model_name_or_path, args.micro_batch_size *
            topo.data_info.size, False, False, worker_index).lower())
    if os.path.exists(log_writer_path):
        import shutil
        shutil.rmtree(log_writer_path)
    log_writer = LogWriter(log_writer_path)

    pretrained_models_list = list(
        model_class.pretrained_init_configuration.keys())
    if args.model_name_or_path in pretrained_models_list:
        # Built-in configuration name: build the model from the stored
        # config with this run's dropout values.
        model_config = model_class.pretrained_init_configuration[
            args.model_name_or_path]
        model_config["hidden_dropout_prob"] = args.hidden_dropout_prob
        model_config[
            "attention_probs_dropout_prob"] = args.attention_probs_dropout_prob
        model = GPTForPretraining(GPTModel(**model_config))
    else:
        model = GPTForPretraining.from_pretrained(
            args.model_name_or_path,
            hidden_dropout_prob=args.hidden_dropout_prob,
            attention_probs_dropout_prob=args.attention_probs_dropout_prob)

    # Create the criterion for the gpt model
    criterion = GPTPretrainingCriterion()

    if paddle.distributed.get_world_size() > 1:
        model = paddle.DataParallel(model)

    if args.decay_steps is None:
        args.decay_steps = args.max_steps
    warmup_step = args.warmup_rate * args.decay_steps

    # Any lr_decay_style other than "cosine" leaves the scheduler unset, in
    # which case the optimizer uses the constant args.max_lr.
    lr_scheduler = None
    if args.lr_decay_style == "none":
        lr_scheduler = None
    elif args.lr_decay_style == "cosine":
        lr_scheduler = lr.CosineAnnealingWithWarmupDecay(
            max_lr=args.max_lr,
            min_lr=args.min_lr,
            warmup_step=warmup_step,
            decay_step=args.decay_steps)

    clip = None
    if args.grad_clip > 0:
        clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=args.grad_clip)

    # Generate parameter names needed to perform weight decay.
    # All bias and LayerNorm parameters are excluded.
    decay_params = [
        p.name for n, p in model.named_parameters()
        if not any(nd in n for nd in ["bias", "norm"])
    ]

    optimizer = paddle.optimizer.AdamW(
        learning_rate=lr_scheduler
        if lr_scheduler is not None else args.max_lr,
        beta1=args.adam_beta1,
        beta2=args.adam_beta2,
        epsilon=args.adam_epsilon,
        parameters=model.parameters(),
        weight_decay=args.weight_decay,
        grad_clip=clip,
        apply_decay_param_fun=lambda x: x in decay_params)

    if args.use_amp:
        scaler = paddle.amp.GradScaler(init_loss_scaling=args.scale_loss)

    if args.model_name_or_path not in pretrained_models_list:
        # model_name_or_path is a checkpoint directory: restore the optimizer
        # state if it was saved alongside the model.
        logger.info("Try to load checkpoint from %s " %
                    args.model_name_or_path)
        opt_path = os.path.join(args.model_name_or_path, "model_state.pdopt")
        if os.path.exists(opt_path):
            opt_dict = paddle.load(opt_path)
            optimizer.set_state_dict(opt_dict)
        else:
            logger.warning("No optimizer checkpoint file found in %s." %
                           opt_path)

    global_step = 0
    epoch = 0
    while True:
        files = get_train_data_file(args)
        if not files:
            raise ValueError(
                "get_train_data_file(args) returned no training data files.")
        files.sort()
        for data_file in files:
            train_data_loader, valid_data_loader, test_data_loader = create_pretrained_dataset(
                args, [data_file],
                local_rank=local_rank,
                data_world_size=topo.data_info.size,
                data_world_rank=topo.data_info.rank,
                eos_id=tokenizer.eos_token_id)
            # Bug fix, if not call valid_data_loader, the enumerate will call valid_data_loader
            # many times. and start a new random dataloader.
            valid_data_loader = valid_data_loader()
            test_data_loader = test_data_loader()

            # time count
            train_reader_cost = 0.0
            train_run_cost = 0.0
            reader_start = time.time()
            for step, batch in enumerate(train_data_loader()):
                train_reader_cost += time.time() - reader_start
                train_start = time.time()

                global_step += 1
                tokens, loss_mask, attention_mask, position_ids, labels = batch

                loss_mask.stop_gradient = True
                attention_mask.stop_gradient = True

                with paddle.amp.auto_cast(
                        args.use_amp,
                        custom_white_list=["layer_norm", "softmax", "gelu"],
                        custom_black_list=[
                            "reduce_sum", "c_softmax_with_cross_entropy",
                            "c_embedding"
                        ]):
                    preds = model(tokens, position_ids, attention_mask)
                    loss = criterion(preds, labels, loss_mask)

                if args.use_amp:
                    scaler.scale(loss).backward()
                    scaler.minimize(optimizer, loss)
                else:
                    loss.backward()
                    optimizer.step()

                if lr_scheduler is not None:
                    lr_scheduler.step()
                optimizer.clear_grad()

                loss_numpy = loss.numpy()
                train_run_cost += time.time() - train_start

                # Profile for model benchmark
                profiler.add_profiler_step(args.profiler_options)

                if global_step % args.logging_freq == 0:
                    speed = args.logging_freq / (
                        train_reader_cost + train_run_cost)
                    avg_reader_cost = train_reader_cost / args.logging_freq
                    logger.info(
                        "global step %d, epoch: %d, batch: %d, loss: %.9f, avg_reader_cost: %.5f sec, avg_batch_cost: %.5f sec, speed: %.2f step/s, ips: %.0f tokens/s, ips_per_card: %.0f tokens/s, learning rate: %.5e"
                        % (global_step, epoch, step, loss_numpy,
                           avg_reader_cost, 1. / speed, speed,
                           speed * default_global_tokens_num,
                           speed * default_global_tokens_num / worker_num,
                           optimizer.get_lr()))
                    log_writer.add_scalar("loss", loss_numpy, global_step)
                    log_writer.add_scalar("learning_rate",
                                          optimizer.get_lr(), global_step)

                    # Start a fresh measurement window.
                    train_reader_cost = 0.0
                    train_run_cost = 0.0

                # Accuracy-check mode skips evaluation and saving entirely.
                if args.check_accuracy:
                    if global_step >= args.max_steps:
                        return
                    # Restart the reader timer before skipping the rest of
                    # the body; otherwise the next reader-cost sample would
                    # also include all training time since the loop began.
                    reader_start = time.time()
                    continue

                if global_step % args.eval_freq == 0:
                    # Since the valid data broadcast to all devices, we do evaluate on all device.
                    run_evaluate(valid_data_loader, model, criterion,
                                 args.eval_iters, log_writer, global_step,
                                 epoch, "valid")

                if global_step % args.save_steps == 0 or global_step >= args.max_steps:
                    # Only rank 0 writes the checkpoint.
                    if worker_index == 0:
                        output_dir = os.path.join(args.output_dir,
                                                  "model_%d" % global_step)
                        if not os.path.exists(output_dir):
                            os.makedirs(output_dir)
                        # Need better way to get inner model of DataParallel
                        model_to_save = model._layers if isinstance(
                            model, paddle.DataParallel) else model
                        logger.info("Save model to %s" % output_dir)
                        model_to_save.save_pretrained(output_dir)
                        tokenizer.save_pretrained(output_dir)
                        paddle.save(
                            optimizer.state_dict(),
                            os.path.join(output_dir, "model_state.pdopt"))

                if global_step >= args.max_steps:
                    run_evaluate(test_data_loader, model, criterion,
                                 args.test_iters, log_writer, global_step,
                                 epoch, "test")
                    logger.info("The training process is complete.")
                    del train_data_loader
                    return

                reader_start = time.time()

            del train_data_loader

        # One full pass over every data file counts as one epoch; previously
        # the counter was never advanced and logs always reported epoch 0.
        epoch += 1
def do_train(args):
    """Run static-graph masked-LM (+ optional sentence-order) pretraining.

    Builds the static program, wraps AdamW with
    ``fleet.distributed_optimizer``, restores parameters from
    ``args.model_name_or_path`` and/or the rolling
    ``<args.output_dir>/model_last`` checkpoint when present, then trains
    until ``args.max_steps`` with gradient-accumulation-aware stepping,
    logging, validation, saving and rolling checkpoints.

    Args:
        args: Parsed argument namespace. Fields read here include ``seed``,
            ``device``, the ``*_degree`` parallel settings, ``output_dir``,
            ``model_type``, ``model_name_or_path``, ``global_batch_size``,
            ``binary_head``, ``use_amp``, ``accumulate_steps``, the
            optimizer settings, ``logging_freq``, ``eval_freq``,
            ``save_steps``, ``checkpoint_steps`` and ``max_steps``.
            ``args.decay_steps`` is overwritten with ``args.max_steps`` when
            it is ``None``.

    Returns:
        None. The function returns once ``global_step`` reaches
        ``args.max_steps``.
    """
    # Initialize the paddle and paddle fleet execute environment
    paddle.enable_static()
    fleet.init(is_collective=True)

    # Create the random seed for the worker. Only 'local_seed' is offset by
    # the worker index.
    random.seed(args.seed)
    np.random.seed(args.seed)
    paddle.seed(args.seed)
    get_rng_state_tracker().add('global_seed', args.seed)
    get_rng_state_tracker().add('local_seed',
                                args.seed + fleet.worker_index() + 2021)

    assert args.device in [
        "cpu", "gpu", "xpu"
    ], "Invalid device! Available device should be cpu, gpu, or xpu."
    place = paddle.set_device(args.device)

    worker_num = fleet.worker_num()
    worker_index = fleet.worker_index()
    assert args.dp_degree * args.sharding_degree * args.mp_degree * args.pp_degree == worker_num, \
        "The product of degree num should be equal to worker_num."

    topo = Topology(
        device_rank=worker_index,
        world_size=worker_num,
        dp_degree=args.dp_degree,
        pp_degree=args.pp_degree,
        sharding_degree=args.sharding_degree,
        mp_degree=args.mp_degree)

    logger.info("The topo of hybrid parallelism:\n{}".format(topo))

    dist_strategy = dist_optimizer(args, topo)

    # Create log write, train results show on last card of pipeline.
    # The name is always bound (None on other ranks) because it is passed to
    # run_evaluate on every rank; previously non-last ranks raised NameError.
    # NOTE(review): run_evaluate also receives topo.is_last and presumably
    # only writes scalars when it is true -- confirm against its definition.
    log_writer = None
    if topo.is_last:
        log_writer_path = os.path.join(
            args.output_dir, "train_log",
            "{}_globalbsz_{}_amp_{}_recompute_{}_card_{}".format(
                args.model_name_or_path, args.global_batch_size, args.use_amp,
                args.use_recompute, worker_index).lower())
        # An existing log directory is deliberately not removed here.
        log_writer = LogWriter(log_writer_path)

    # Define the input data in the static mode
    base_class, model_class, criterion_class, tokenizer_class = MODEL_CLASSES[
        args.model_type]
    pretrained_models_list = list(
        model_class.pretrained_init_configuration.keys())

    # load config in checkpoint
    global_step = 0
    checkpoint_dir = os.path.join(args.output_dir, "model_last")
    checkpoint_config = os.path.join(checkpoint_dir, "./config.yml")
    if os.path.exists(checkpoint_dir):
        if os.path.isfile(checkpoint_config):
            with open(checkpoint_config, "r") as f:
                step_config = yaml.load(f, Loader=yaml.FullLoader)
                assert step_config[
                    "global_batch_size"] == args.global_batch_size, "Please ensure checkpoint global batch size is the same. Folder: {}".format(
                        checkpoint_dir)
                # Resume the step counter from the rolling checkpoint.
                global_step = step_config["global_step"]

    data_file = get_train_data_file(args)
    main_program = paddle.static.default_main_program()
    startup_program = paddle.static.default_startup_program()
    with paddle.static.program_guard(main_program, startup_program):
        data_holders = create_data_holder(args)
        # 0. input_ids,
        # 1. segment_ids,
        # 2. input_mask,
        # 3. masked_lm_positions,
        # 4. masked_lm_labels,
        # 5. next_sentence_labels

        [
            input_ids, segment_ids, input_mask, masked_lm_positions,
            masked_lm_labels, next_sentence_labels
        ] = data_holders

        tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)

        train_data_loader, valid_data_loader, test_data_loader = create_pretrained_dataset(
            args,
            data_file,
            tokenizer,
            data_world_size=topo.data_info.size,
            data_world_rank=topo.data_info.rank,
            max_seq_len=args.max_seq_len,
            places=paddle.static.cuda_places(),
            data_holders=data_holders,
            current_step=global_step)
        # NOTE(review): fleet.init was already called at the top of this
        # function; this second call is kept as-is -- confirm it is needed.
        fleet.init(is_collective=True)

        if args.model_name_or_path in pretrained_models_list:
            model_config = model_class.pretrained_init_configuration[
                args.model_name_or_path]
            # Round the vocabulary size up to the next multiple of 8.
            if model_config["vocab_size"] % 8 != 0:
                model_config["vocab_size"] += 8 - (model_config["vocab_size"] %
                                                   8)
            model_config["hidden_dropout_prob"] = args.hidden_dropout_prob
            model_config[
                "attention_probs_dropout_prob"] = args.attention_probs_dropout_prob
            model = model_class(base_class(**model_config))
        else:
            model, _ = model_class.from_pretrained(
                args.model_name_or_path,
                hidden_dropout_prob=args.hidden_dropout_prob,
                attention_probs_dropout_prob=args.attention_probs_dropout_prob,
            )

        # Create the model for the pretrain
        prediction_scores, seq_relationship_score = model(
            input_ids=input_ids,
            token_type_ids=segment_ids,
            position_ids=None,
            attention_mask=input_mask,
            masked_positions=masked_lm_positions)

        criterion = criterion_class(with_nsp_loss=args.binary_head)
        if args.binary_head:
            lm_loss, sop_loss = criterion(prediction_scores,
                                          seq_relationship_score,
                                          masked_lm_labels,
                                          next_sentence_labels)
            loss = lm_loss + sop_loss
        else:
            loss = criterion(prediction_scores, seq_relationship_score,
                             masked_lm_labels)

        # Create the learning_rate scheduler and optimizer
        if args.decay_steps is None:
            args.decay_steps = args.max_steps

        # last_epoch=global_step lets the schedule continue from a resumed
        # step count.
        lr_scheduler = LinearDecayWithWarmup(
            args.max_lr,
            args.max_steps,
            args.warmup_rate,
            last_epoch=global_step)

        clip = None
        if args.grad_clip > 0:
            clip = paddle.fluid.clip.GradientClipByGlobalNorm(
                clip_norm=args.grad_clip)

        # Weight decay is applied only to parameters whose structured name
        # contains neither "bias" nor "norm".
        decay_param = [
            p.name for n, p in model.named_parameters()
            if not any(nd in n for nd in ["bias", "norm"])
        ]
        logger.info("Using paddle.optimizer.AdamW.")
        optimizer = paddle.optimizer.AdamW(
            learning_rate=lr_scheduler,
            beta1=args.adam_beta1,
            beta2=args.adam_beta2,
            epsilon=args.adam_epsilon,
            grad_clip=clip,
            weight_decay=args.weight_decay,
            apply_decay_param_fun=lambda x: x in decay_param)
        # alias
        optimizer.apply_optimize = optimizer._apply_optimize

        # Use the fleet api to compile the distributed optimizer
        optimizer = fleet.distributed_optimizer(
            optimizer, strategy=dist_strategy)

        optimizer.minimize(loss)
        logger.info(f'final strategy: {fleet._final_strategy()}')
        logger.info("The training meta optimizer is/are %s" %
                    fleet._get_applied_meta_list())

    # Dump the compiled programs for inspection. makedirs(exist_ok=True)
    # also creates a missing args.output_dir and does not fail when another
    # worker creates the directory at the same time.
    program_desc_dir = os.path.join(args.output_dir, "program_desc")
    os.makedirs(program_desc_dir, exist_ok=True)

    with open(program_desc_dir + "/main_program.txt.%d" % worker_index,
              'w') as f:
        f.write(str(main_program))

    with open(program_desc_dir + "/startup_program.txt.%d" % worker_index,
              'w') as f:
        f.write(str(startup_program))

    # Define the Executor for running the static model
    exe = paddle.static.Executor(place)
    exe.run(startup_program)
    test_program = main_program.clone(for_test=True)

    if args.model_name_or_path not in pretrained_models_list:
        # model_name_or_path is treated as a checkpoint directory: static
        # variables are tried first, then dygraph parameters.
        logger.info("Try to load checkpoint from %s " %
                    args.model_name_or_path)
        dygraph_path = os.path.join(args.model_name_or_path,
                                    "model_state.pdparams")
        static_path = os.path.join(args.model_name_or_path, "static_vars")

        flag_loaded = False
        if os.path.exists(static_path):
            if args.mp_degree > 1:
                logger.warning("MP should init with dygraph params")
            else:
                logger.info("Loading parameters from %s" % static_path)
                paddle.static.load(main_program, static_path, exe)
                flag_loaded = True

        if not flag_loaded and os.path.exists(dygraph_path):
            if args.sharding_degree > 1:
                logger.warning("Sharding should init with static vars")
            else:
                logger.info("Loading parameters from %s" % dygraph_path)
                init_static_with_params(
                    model,
                    paddle.load(
                        dygraph_path, return_numpy=True),
                    topo,
                    main_program)
                flag_loaded = True

        if not flag_loaded:
            logger.error("No checkpoint load.")

    # load checkpoint vars
    if os.path.exists(checkpoint_dir):
        if os.path.isfile(checkpoint_config):
            paddle.static.load(main_program,
                               os.path.join(checkpoint_dir, "static_vars"),
                               exe)

    # Variables fetched on the last pipeline stage, grouped so the log line
    # can format losses separately from the other values.
    fetch_loss_vars = collections.OrderedDict()
    fetch_other_vars = collections.OrderedDict()
    fetch_loss_vars["loss"] = loss
    if args.binary_head:
        fetch_loss_vars["lm_loss"] = lm_loss
        fetch_loss_vars["sop_loss"] = sop_loss

    fetch_other_vars["learning_rate"] = main_program.global_block().vars[
        "learning_rate_0"]

    additional_vars = collections.OrderedDict()
    if args.use_amp:
        for key in ["loss_scaling", "num_good_steps", "num_bad_steps"]:
            additional_vars[key] = main_program.global_block().vars[key + "_0"]

    fetchs = []
    fetchs_keys = []
    if topo.is_last:
        fetchs = list(fetch_loss_vars.values()) + list(
            fetch_other_vars.values()) + list(additional_vars.values())
        fetchs_keys = list(fetch_loss_vars.keys()) + list(
            fetch_other_vars.keys()) + list(additional_vars.keys())

    # Bug fix, if not call valid_data_loader, the enumerate will call
    # valid_data_loader many times. and start a new random dataloader.
    # The calls are made once, before the epoch loop: rebinding inside the
    # loop would call the already-created iterator on a second epoch.
    valid_data_loader = valid_data_loader()
    test_data_loader = test_data_loader()

    tic_train = time.time()
    while True:
        for step, batch in enumerate(train_data_loader()):
            ret = exe.run(main_program,
                          feed=batch,
                          fetch_list=fetchs,
                          use_program_cache=True)
            # Skip for accumulate_steps in global step
            if (step + 1) % args.accumulate_steps != 0:
                continue
            global_step += 1
            # In the new 2.0 api, must call this function to change the learning_rate
            lr_scheduler.step()

            if global_step % args.logging_freq == 0:
                if topo.is_last:
                    res = collections.defaultdict(float)
                    for k, v in zip(fetchs_keys, ret):
                        res[k] = v[0]

                    speed = args.logging_freq / (time.time() - tic_train)

                    loss_info = ", ".join([
                        "{}: {:.6f}".format(k, res[k])
                        for k in fetch_loss_vars
                    ])

                    common_loginfo = "global step %d, %s, speed: %.2f steps/s, ips: %.2f seqs/s, learning rate: %.5e" % (
                        global_step, loss_info, speed,
                        speed * args.global_batch_size, res["learning_rate"])
                    additional_loginfo = ", ".join([
                        "{}: {}".format(k, res[k]) for k in additional_vars
                    ])
                    if additional_loginfo:
                        common_loginfo += ", " + additional_loginfo
                    logger.info(common_loginfo)
                    for k, v in res.items():
                        log_writer.add_scalar(k, v, global_step)

                tic_train = time.time()

            if global_step % args.eval_freq == 0:
                # TODO, check the input data of validation
                eval_fetch = collections.OrderedDict()
                if topo.is_last:
                    eval_fetch["loss"] = loss
                    if args.binary_head:
                        eval_fetch["lm_loss"] = lm_loss
                        eval_fetch["sop_loss"] = sop_loss

                run_evaluate(valid_data_loader, exe, test_program,
                             args.eval_iters, log_writer, global_step, args,
                             topo.is_last, eval_fetch, "valid")
                # Restart the timer so evaluation time is not counted as
                # training time in the next speed report.
                tic_train = time.time()

            if global_step % args.save_steps == 0 or global_step >= args.max_steps:
                output_dir = os.path.join(args.output_dir,
                                          "model_%d" % global_step)
                logger.debug("saving models to {}".format(output_dir))
                save_persistables(exe,
                                  os.path.join(output_dir, "static_vars"),
                                  main_program)
                if global_step == args.save_steps:
                    model.init_config["init_args"][0].init_config.pop(
                        "topo", None)
                model.save_pretrained(output_dir)
                tokenizer.save_pretrained(output_dir)
                tic_train = time.time()

            if global_step % args.checkpoint_steps == 0:
                # Rolling checkpoint: the previous "model_last" is moved to
                # "model_last_bak" before a fresh one is written.
                output_dir = os.path.join(args.output_dir, "model_last")
                if worker_index == 0:
                    if not os.path.exists(output_dir):
                        os.mkdir(output_dir)
                    output_dir_bak = os.path.join(args.output_dir,
                                                  "model_last_bak")
                    if os.path.exists(output_dir):
                        if os.path.exists(output_dir_bak):
                            shutil.rmtree(output_dir_bak)
                        shutil.move(output_dir, output_dir_bak)
                        os.mkdir(output_dir)

                    step_config = {
                        "model_name": args.model_name_or_path,
                        "global_step": global_step,
                        "global_batch_size": args.global_batch_size,
                        "consumed_samples":
                        global_step * args.global_batch_size,
                    }

                    with open(os.path.join(output_dir, "config.yml"),
                              "w") as f:
                        yaml.dump(
                            step_config,
                            f,
                            encoding='utf-8',
                            allow_unicode=True)

                # All workers wait here until worker 0 has prepared the
                # directory and written the config.
                fleet.barrier_worker()

                logger.debug("saving models to {}".format(output_dir))
                if args.sharding_degree <= 1:
                    # Save on the first worker by default.
                    if worker_index == 0:
                        paddle.static.save(
                            main_program,
                            os.path.join(output_dir, "static_vars"))
                else:
                    # Use save_persistables in sharding, but more slower
                    save_persistables(exe,
                                      os.path.join(output_dir, "static_vars"),
                                      main_program)

            if global_step >= args.max_steps:
                eval_fetch = collections.OrderedDict()
                if topo.is_last:
                    eval_fetch["loss"] = loss
                    if args.binary_head:
                        eval_fetch["lm_loss"] = lm_loss
                        eval_fetch["sop_loss"] = sop_loss

                run_evaluate(test_data_loader, exe, test_program,
                             args.test_iters, log_writer, global_step, args,
                             topo.is_last, eval_fetch, "test")
                del train_data_loader
                return
def do_generation(args):
    """Run static-graph GPT text generation with tensor (model) parallelism.

    Builds a ``GPTForGeneration`` program, loads the pretrained dygraph
    weights for a built-in model name, generates continuations for two
    hard-coded prompts and logs them, runs a few batches from the test data
    loader, and optionally exports an inference model.

    Args:
        args: Parsed argument namespace. Fields read here include
            ``dp_degree``/``sharding_degree``/``pp_degree`` (all must be 1),
            ``mp_degree`` (set to the world size when it is 1), ``seed``,
            ``use_amp``/``amp_level``/``use_sharding``, ``device``,
            ``model_type``, ``model_name_or_path``, ``max_seq_len``,
            ``fuse``, the decoding settings (``max_dec_len``,
            ``decoding_strategy``, ``temperature``, ``topk``, ``topp``) and
            ``save_inference_model_then_exist``.

    Returns:
        None.

    Raises:
        ValueError: If ``args.model_name_or_path`` is not one of the built-in
            pretrained configuration names. Previously this case only logged
            an error and then failed on the unbound ``model`` variable.
    """
    # Initialize the paddle and paddle fleet execute environment
    paddle.enable_static()

    assert args.dp_degree == 1, "Data parallel is not supported in inference"
    assert args.sharding_degree == 1, "Sharding parallel is temporarily not supported in inference"
    assert args.pp_degree == 1, "Pipeline parallel will be supported later"

    if args.mp_degree == 1:
        args.mp_degree = paddle.distributed.get_world_size()
    else:
        assert args.mp_degree == paddle.distributed.get_world_size(), \
            "If mp_degree is specified, the size must be the same as world_size"

    strategy = fleet.DistributedStrategy()
    strategy.tensor_parallel = True
    strategy.tensor_parallel_configs = {
        "tensor_parallel_degree": args.mp_degree
    }

    fleet.init(is_collective=True, strategy=strategy)

    # temp use dynamic init, use HybridParallelInferenceHelper in future?
    paddle.distributed.init_parallel_env()

    # Create the random seed for the worker. Only 'local_seed' is offset by
    # the worker index.
    random.seed(args.seed)
    np.random.seed(args.seed)
    paddle.seed(args.seed)
    get_rng_state_tracker().add('global_seed', args.seed)
    get_rng_state_tracker().add('local_seed',
                                args.seed + fleet.worker_index() + 2021)

    if args.use_amp and args.amp_level == "O2":
        assert (args.mp_degree == 1 and args.pp_degree == 1
                ), "When amp level is O2, mp_degree and pp_degree should be 1."
        assert (args.use_sharding == False
                ), "When amp level is O2, use_sharding should be False."

    assert args.device in [
        "cpu", "gpu", "xpu"
    ], "Invalid device! Available device should be cpu, gpu, or xpu."
    place = paddle.set_device(args.device)

    worker_num = fleet.worker_num()
    worker_index = fleet.worker_index()
    fleet_local_rank = fleet.local_rank()
    local_rank = 0 if fleet_local_rank is None else int(fleet_local_rank)

    topo = Topology(
        device_rank=worker_index,
        world_size=worker_num,
        dp_degree=args.dp_degree,
        pp_degree=args.pp_degree,
        sharding_degree=args.sharding_degree,
        mp_degree=args.mp_degree)

    logger.info("The topo of hybrid parallelism:\n{}".format(topo))

    model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
    pretrained_models_list = list(
        model_class.pretrained_init_configuration.keys())

    data_file = get_data_file(args)
    main_program = paddle.static.default_main_program()
    startup_program = paddle.static.default_startup_program()
    with paddle.static.program_guard(main_program, startup_program):
        with paddle.utils.unique_name.guard():
            with paddle.static.device_guard('gpu:0'):
                feeds = create_data_holder(args)
                tokenizer = tokenizer_class.from_pretrained(
                    args.model_name_or_path)
                eos_id = tokenizer.eos_token_id

                # Only the test loader is used for generation.
                _, _, test_data_loader = create_pretrained_dataset(
                    args,
                    data_file,
                    local_rank=local_rank,
                    data_world_size=topo.data_info.size,
                    data_world_rank=topo.data_info.rank,
                    eos_id=eos_id,
                    max_seq_len=args.max_seq_len,
                    places=paddle.static.cuda_places(),
                    data_holders=feeds,
                    pipeline_mode=False)

                if args.model_name_or_path not in pretrained_models_list:
                    # Only built-in configuration names are supported here;
                    # fail clearly instead of hitting an unbound `model`.
                    logger.error("No checkpoint load.")
                    raise ValueError(
                        "Unsupported model_name_or_path {!r}; expected one "
                        "of {}".format(args.model_name_or_path,
                                       pretrained_models_list))

                model_config = model_class.pretrained_init_configuration[
                    args.model_name_or_path]
                model_config[
                    "hidden_dropout_prob"] = args.hidden_dropout_prob
                model_config[
                    "attention_probs_dropout_prob"] = args.attention_probs_dropout_prob
                model_config["topo"] = topo
                model_config["fuse"] = args.fuse
                model = GPTForGeneration(
                    GPTModel(**model_config),
                    max_length=args.max_dec_len,
                    decoding_strategy=args.decoding_strategy,
                    temperature=args.temperature,
                    top_k=args.topk,
                    top_p=args.topp,
                    eos_id=eos_id,
                    fuse=args.fuse)

                model.eval()
                # The model takes its inputs as a dict keyed by variable name.
                ins = {v.name: v for v in feeds}
                preds = model(ins)

    # Define the Executor for running the static model
    exe = paddle.static.Executor(place)
    exe.run(startup_program)
    main_program = main_program.clone(for_test=True)

    model_urls = model.pretrained_resource_files_map['model_state']
    model_path = args.model_name_or_path
    if model_path in pretrained_models_list and model_path in model_urls:
        flag_loaded = False
        from paddle.utils.download import get_weights_path_from_url
        dygraph_path = get_weights_path_from_url(model_urls[model_path])
        if os.path.exists(dygraph_path):
            if args.sharding_degree > 1:
                logger.warning("Sharding should init with static vars")
            else:
                logger.info("Loading parameters from %s" % dygraph_path)
                init_static_with_params(
                    model,
                    paddle.load(
                        dygraph_path, return_numpy=True),
                    topo,
                    main_program)
                flag_loaded = True
        if not flag_loaded:
            logger.error("No checkpoint load.")

    fetchs = [preds]

    # Check results: generate continuations for two fixed prompts and log
    # each prompt followed by its decoded output.
    text = [
        "Question: Where is the capital of China? Answer:",
        "Question:Who is the CEO of Apple? Answer:"
    ]
    inputs = tokenizer(
        text,
        padding=True,
        return_attention_mask=True,
        return_position_ids=True)
    ids = np.array(inputs["input_ids"]).reshape(len(text), -1).astype('int64')
    position_ids = np.array(inputs["position_ids"]).reshape(len(text),
                                                            -1).astype('int64')
    attention_mask = np.array(inputs["attention_mask"]).reshape(
        len(text), -1).astype('float32')

    t_ids = paddle.fluid.core.Tensor()
    t_ids.set(ids, place)
    t_mask = paddle.fluid.core.Tensor()
    t_mask.set(attention_mask, place)
    t_pos = paddle.fluid.core.Tensor()
    t_pos.set(position_ids, place)
    feed_data = {'src_ids': t_ids, 'pos_ids': t_pos, 'input_mask': t_mask}
    ret = exe.run(main_program, feed=feed_data, fetch_list=fetchs)
    ret = np.array(ret[0])
    for i in range(ret.shape[0]):
        o = [int(x) for x in ret[i]]
        ret_str = tokenizer.convert_ids_to_string(o)
        ret_str = text[i] + ret_str
        logger.info(ret_str)

    # Run batches 0..5 of the test loader; the outputs are not inspected.
    for step, batch in enumerate(test_data_loader()):
        ret = exe.run(main_program, feed=batch, fetch_list=fetchs)
        if step == 5:
            break

    if args.save_inference_model_then_exist:
        save_inference_model_dir = 'inference_model_pp{pp_degree}mp{mp_degree}'.format(
            pp_degree=args.pp_degree, mp_degree=args.mp_degree)
        inference_save_path = os.path.join(save_inference_model_dir,
                                           'rank_' + str(fleet.worker_index()),
                                           'step_' + str(0))
        print("saving inference models to {}".format(inference_save_path))
        feed_names = [v.name for v in feeds]
        fetchs_names = [v.name for v in fetchs]
        print('feeds: ', feed_names, 'fetches: ', fetchs_names)
        paddle.static.save_inference_model(
            inference_save_path, feeds, fetchs, exe, program=main_program)