def test_fleet_get_applied_optimizer(self):
    """Check fleet's applied meta/graph optimizer lists around ``minimize``.

    A small fully connected classifier is built, fleet is initialised in
    collective mode, and the test asserts that both lists are empty before
    ``minimize`` is called. After ``minimize`` the meta list is expected to
    be empty and the graph list to hold exactly one entry.
    """
    layers = paddle.fluid.layers

    # Network: two tanh hidden layers followed by a 2-way softmax.
    features = layers.data(name="x", shape=[32], dtype='float32')
    labels = layers.data(name="y", shape=[1], dtype='int64')
    hidden_1 = layers.fc(input=features, size=64, act='tanh')
    hidden_2 = layers.fc(input=hidden_1, size=64, act='tanh')
    probs = layers.fc(input=[hidden_2], size=2, act='softmax')
    per_sample_cost = layers.cross_entropy(input=probs, label=labels)
    avg_cost = layers.mean(x=per_sample_cost)

    fleet.init(is_collective=True)

    # minimize() has not been called yet.
    meta_before = fleet._get_applied_meta_list()
    graph_before = fleet._get_applied_graph_list()
    self.assertEqual(len(meta_before), 0)
    self.assertEqual(len(graph_before), 0)

    sgd = paddle.fluid.optimizer.SGD(learning_rate=0.001)
    dist_sgd = fleet.distributed_optimizer(
        sgd, strategy=fleet.DistributedStrategy())
    dist_sgd.minimize(avg_cost)

    meta_after = fleet._get_applied_meta_list()
    graph_after = fleet._get_applied_graph_list()
    self.assertEqual(len(meta_after), 0)
    self.assertEqual(len(graph_after), 1)
def test_amp_recompute_lamb_optimizer(self):
    """Enable amp, recompute and lamb together and inspect the program.

    Expects three applied meta optimizers, ``cast`` and
    ``check_finite_and_unscale`` ops in the loss block, a ``mul`` output
    whose name contains ``subprog``, and a ``lamb`` op.
    """
    train_prog, startup_prog = fluid.Program(), fluid.Program()
    avg_cost, strategy = self.net(train_prog, startup_prog)

    # Same order as enabling them one by one: amp, recompute, lamb.
    for feature in ('amp', 'recompute', 'lamb'):
        self.set_strategy(strategy, feature)
    self.optimizer(avg_cost, strategy, train_prog, startup_prog, 'adam')

    applied_meta_list = fleet._get_applied_meta_list()
    applied_graph_list = fleet._get_applied_graph_list()
    print(applied_meta_list, applied_graph_list)
    self.assertEqual(len(applied_meta_list), 3)

    op_types = []
    for op in avg_cost.block.ops:
        op_types.append(op.type)

    mul_outputs = []
    for op in avg_cost.block.ops:
        if op.type == 'mul':
            mul_outputs.append(op.output('Out')[0])

    self.assertIn('cast', op_types)
    self.assertIn('check_finite_and_unscale', op_types)

    # recompute
    self.assertIn('subprog', ''.join(mul_outputs))

    # lamb
    self.assertIn('lamb', op_types)
def test_fleet_amp_meta_optimizer_init(self):
    """Smoke-test ``amp_init`` with pure-fp16 AMP plus gradient merge.

    Returns immediately when Paddle was not compiled with CUDA. Otherwise
    builds the ``mlp`` network under a fresh program pair, minimizes it with
    a fleet-wrapped Momentum optimizer, runs the startup program, calls
    ``amp_init`` and executes three training steps, printing each fetched
    cost.
    """
    if not fluid.core.is_compiled_with_cuda():
        return

    main_program = paddle.static.Program()
    startup_program = paddle.static.Program()

    fleet.init(role_maker.PaddleCloudRoleMaker(is_collective=True))

    with paddle.static.program_guard(main_program, startup_program):
        features = paddle.static.data(name="x",
                                      shape=[None, 32],
                                      dtype='float32')
        labels = paddle.static.data(name="y", shape=[None, 1], dtype='int64')

        cost = mlp(features, labels)
        momentum = paddle.optimizer.Momentum(
            learning_rate=0.001,
            momentum=0.9,
            weight_decay=fluid.regularizer.L2Decay(1e-4),
            multi_precision=True)

        strategy = paddle.distributed.fleet.DistributedStrategy()
        strategy.amp = True
        strategy.amp_configs = {'use_pure_fp16': True}
        strategy.gradient_merge = True
        strategy.gradient_merge_configs = {"k_steps": 2}

        optimizer = fleet.distributed_optimizer(momentum, strategy)
        optimizer.minimize(cost)

    print(fleet._get_applied_meta_list())

    # The return value is not inspected; the call itself is what is exercised.
    optimizer.get_loss_scaling()

    place = paddle.CUDAPlace(0)
    exe = paddle.static.Executor(place)
    exe.run(startup_program)
    optimizer.amp_init(place)

    num_steps = 3
    for _ in range(num_steps):
        fetched = exe.run(program=main_program,
                          feed=gen_data(),
                          fetch_list=[cost.name])
        print(fetched)
def do_train(args):
    """Run static-graph GPT pretraining under fleet hybrid parallelism.

    The function seeds the RNGs, builds the parallel topology and the
    distributed strategy, constructs the model/loss/optimizer inside the
    default main program, optionally loads initial parameters, and then
    loops over the training data with periodic logging, evaluation and
    checkpointing. It returns once ``global_step`` reaches ``args.max_steps``.

    Args:
        args: Parsed run configuration. Attributes read here include
            ``seed``, ``device``, the ``*_degree`` parallelism sizes,
            ``output_dir``, ``model_type``, ``model_name_or_path``, the
            optimizer/lr hyper-parameters, and the ``logging_freq`` /
            ``eval_freq`` / ``save_steps`` / ``max_steps`` schedule.
            ``args.decay_steps`` is overwritten with ``args.max_steps`` when
            it is ``None``.

    Raises:
        AssertionError: If ``args.device`` is not one of cpu/gpu/xpu.
    """
    # Initialize the paddle and paddle fleet execute environment
    paddle.enable_static()
    fleet.init(is_collective=True)

    # Create the random seed for the worker
    random.seed(args.seed)
    np.random.seed(args.seed)
    paddle.seed(args.seed)
    get_rng_state_tracker().add('global_seed', args.seed)
    get_rng_state_tracker().add('local_seed',
                                args.seed + fleet.worker_index() + 2021)

    assert args.device in [
        "cpu", "gpu", "xpu"
    ], "Invalid device! Available device should be cpu, gpu, or xpu."
    place = paddle.set_device(args.device)

    worker_num = fleet.worker_num()
    worker_index = fleet.worker_index()

    topo = Topology(device_rank=worker_index,
                    world_size=worker_num,
                    dp_degree=args.dp_degree,
                    pp_degree=args.pp_degree,
                    sharding_degree=args.sharding_degree,
                    mp_degree=args.mp_degree)

    logger.info("The topo of hybrid parallelism:\n{}".format(topo))

    dist_strategy = dist_optimizer(args, topo)

    # Create log write, train results show on last card of pipeline.
    # Every rank passes `log_writer` to run_evaluate below, so the name must
    # be bound on ranks that do not own a writer as well.
    log_writer = None
    if topo.is_last:
        log_writer_path = os.path.join(
            args.output_dir, "train_log",
            "{}_globalbsz_{}_amp_{}_recompute_{}_card_{}".format(
                args.model_name_or_path, args.global_batch_size, args.use_amp,
                args.use_recompute, worker_index).lower())
        if os.path.exists(log_writer_path):
            import shutil
            shutil.rmtree(log_writer_path)
        log_writer = LogWriter(log_writer_path)

    # Define the input data in the static mode
    model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
    pretrained_models_list = list(
        model_class.pretrained_init_configuration.keys())

    data_file = get_train_data_file(args)
    main_program = paddle.static.default_main_program()
    startup_program = paddle.static.default_startup_program()
    with paddle.static.program_guard(main_program, startup_program):
        with paddle.utils.unique_name.guard():
            with paddle.static.device_guard('gpu:0'):
                data_holders = create_data_holder(args)
                [tokens, loss_mask, attention_mask, position_ids,
                 labels] = data_holders

                tokenizer = tokenizer_class.from_pretrained(
                    args.model_name_or_path)
                eos_id = tokenizer.eos_token_id

                train_data_loader, valid_data_loader, test_data_loader = create_pretrained_dataset(
                    args,
                    data_file,
                    data_world_size=topo.data_info.size,
                    data_world_rank=topo.data_info.rank,
                    eos_id=eos_id,
                    max_seq_len=args.max_seq_len,
                    places=paddle.static.cuda_places(),
                    data_holders=data_holders,
                    pipeline_mode=False,
                )

                if args.model_name_or_path in pretrained_models_list:
                    # Known configuration name: build the model from its
                    # registered config, overriding the dropout settings.
                    model_config = model_class.pretrained_init_configuration[
                        args.model_name_or_path]

                    model_config[
                        "hidden_dropout_prob"] = args.hidden_dropout_prob
                    model_config[
                        "attention_probs_dropout_prob"] = args.attention_probs_dropout_prob
                    model_config["topo"] = topo

                    model = guard(f'gpu:{args.pp_degree -1}')(
                        GPTForPretraining)(
                            guard(f'gpu:0')(GPTModel)(**model_config))
                else:
                    model, _ = GPTForPretraining.from_pretrained(
                        args.model_name_or_path,
                        hidden_dropout_prob=args.hidden_dropout_prob,
                        attention_probs_dropout_prob=args.
                        attention_probs_dropout_prob,
                        topo=topo)

                # Create the model for the gpt pretrain
                preds = model(tokens, position_ids, attention_mask)

                criterion = guard(f'gpu:{args.pp_degree -1}')(
                    GPTPretrainingCriterion)(topo)
                loss = criterion(preds, labels, loss_mask)

            # Create the learning_rate scheduler and optimizer
            if args.decay_steps is None:
                args.decay_steps = args.max_steps
            warmup_step = args.warmup_rate * args.decay_steps

            # TODO @ZHUI Use paddle network to support lr scheduler
            lr_scheduler = lr.CosineAnnealingWithWarmupDecay(
                max_lr=args.max_lr,
                min_lr=args.min_lr,
                warmup_step=warmup_step,
                decay_step=args.decay_steps)

            clip = None
            if args.grad_clip > 0:
                clip = paddle.fluid.clip.GradientClipByGlobalNorm(
                    clip_norm=args.grad_clip)

            # Weight decay is applied only to parameters whose structured
            # name contains neither "bias" nor "norm".
            decay_param = [
                p.name for n, p in model.named_parameters()
                if not any(nd in n for nd in ["bias", "norm"])
            ]

            optimizer = paddle.optimizer.AdamW(
                learning_rate=lr_scheduler,
                beta1=args.adam_beta1,
                beta2=args.adam_beta2,
                epsilon=args.adam_epsilon,
                grad_clip=clip,
                weight_decay=args.weight_decay,
                apply_decay_param_fun=lambda x: x in decay_param)
            # alias
            optimizer.apply_optimize = optimizer._apply_optimize

            if args.use_recompute:
                dist_strategy.recompute = True
                dist_strategy.recompute_configs = {
                    "checkpoints": model.gpt.checkpoints
                }

            # Use the fleet api to compile the distributed optimizer
            optimizer = fleet.distributed_optimizer(optimizer,
                                                    strategy=dist_strategy)

            optimizer.minimize(loss)
            logger.info(f'final strategy: {fleet._final_strategy()}')
            logger.info("The training meta optimizer is/are %s" %
                        fleet._get_applied_meta_list())

    # Dump the program descriptions for debugging. makedirs(exist_ok=True)
    # avoids the check-then-create race between concurrently starting workers
    # and also creates args.output_dir when it does not exist yet.
    program_desc_dir = os.path.join(args.output_dir, "program_desc")
    os.makedirs(program_desc_dir, exist_ok=True)

    with open(program_desc_dir + "/main_program.txt.%d" % worker_index,
              'w') as f:
        f.write(str(main_program))

    with open(program_desc_dir + "/startup_program.txt.%d" % worker_index,
              'w') as f:
        f.write(str(startup_program))

    # Define the Executor for running the static model
    exe = paddle.static.Executor(place)
    exe.run(startup_program)
    test_program = main_program.clone(for_test=True)

    if args.model_name_or_path not in pretrained_models_list:
        # model_name_or_path is treated as a directory: prefer static vars,
        # fall back to dygraph params.
        logger.info("Try to load checkpoint from %s " %
                    args.model_name_or_path)
        dygraph_path = os.path.join(args.model_name_or_path,
                                    "model_state.pdparams")
        static_path = os.path.join(args.model_name_or_path, "static_vars")

        flag_loaded = False
        if os.path.exists(static_path):
            if args.mp_degree > 1:
                logger.warning("MP should init with dygraph params")
            else:
                logger.info("Loading parameters from %s" % static_path)
                paddle.static.load(main_program, static_path, exe)
                flag_loaded = True

        if not flag_loaded and os.path.exists(dygraph_path):
            if args.sharding_degree > 1:
                logger.warning("Sharding should init with static vars")
            else:
                logger.info("Loading parameters from %s" % dygraph_path)
                init_static_with_params(
                    model, paddle.load(dygraph_path, return_numpy=True), topo,
                    main_program)
                flag_loaded = True

        if not flag_loaded:
            logger.error("No checkpoint load.")

    global_step = 0
    tic_train = time.time()
    epoch = 0
    learning_rate = main_program.global_block().vars["learning_rate_0"]
    while True:
        # Only the last pipeline stage fetches the loss and learning rate.
        fetchs = []
        if topo.is_last:
            fetchs = [loss, learning_rate]

        # Bug fix, if not call valid_data_loader, the enumerate will call valid_data_loader
        # many times. and start a new random dataloader.
        # NOTE(review): both names are rebound to the call result on every
        # pass of this outer loop, so a second pass calls the previous result
        # again - confirm the returned objects are themselves callable.
        valid_data_loader = valid_data_loader()
        test_data_loader = test_data_loader()

        for step, batch in enumerate(train_data_loader()):
            global_step += 1
            ret = exe.run(main_program,
                          feed=batch,
                          fetch_list=fetchs,
                          use_program_cache=True)
            # In the new 2.0 api, must call this function to change the learning_rate
            lr_scheduler.step()

            if global_step % args.logging_freq == 0:
                if topo.is_last:
                    loss_return, lr_return = ret
                    speed = args.logging_freq / (time.time() - tic_train)
                    logger.info(
                        "global step %d, epoch: %d, batch: %d, loss: %.9f, speed: %.2f steps/s, ips: %.0f tokens/s, learning rate: %.5e"
                        % (global_step, epoch, step, loss_return[0], speed,
                           speed * args.global_batch_size * args.max_seq_len,
                           lr_return[0]))
                    log_writer.add_scalar("loss", loss_return[0], global_step)
                    log_writer.add_scalar("learning_rate", lr_return[0],
                                          global_step)
                tic_train = time.time()

            # Accuracy-check mode skips evaluation and saving entirely.
            if args.check_accuracy:
                if global_step >= args.max_steps:
                    return
                else:
                    continue

            if global_step % args.eval_freq == 0:
                # TODO, check the input data of validation
                eval_fetch = []
                if topo.is_last:
                    eval_fetch = [loss]

                run_evaluate(valid_data_loader, exe, test_program,
                             args.eval_iters, log_writer, global_step, args,
                             epoch, topo.is_last, eval_fetch, "valid")
                tic_train = time.time()

            if global_step % args.save_steps == 0 or global_step >= args.max_steps:
                output_dir = os.path.join(args.output_dir,
                                          "model_%d" % global_step)
                logger.debug("saving models to {}".format(output_dir))
                save_persistables(exe, os.path.join(output_dir, "static_vars"),
                                  main_program)
                if global_step == args.save_steps:
                    # Drop the topology object from the stored init config
                    # once, before the first save_pretrained call.
                    model.init_config["init_args"][0].init_config.pop(
                        "topo", None)
                model.save_pretrained(output_dir)
                tokenizer.save_pretrained(output_dir)
                tic_train = time.time()

            if global_step >= args.max_steps:
                eval_fetch = []
                if topo.is_last:
                    eval_fetch = [loss]

                run_evaluate(test_data_loader, exe, test_program,
                             args.test_iters, log_writer, global_step, args,
                             epoch, topo.is_last, eval_fetch, "test")
                del train_data_loader
                return
        epoch += 1
def do_train(args):
    """Run static-graph masked-LM pretraining with resumable checkpoints.

    The function seeds the RNGs, builds the parallel topology and the
    distributed strategy, restores ``global_step`` from
    ``<output_dir>/model_last/config.yml`` when present, constructs the
    model/loss/optimizer in the default main program, optionally loads
    initial parameters and the last checkpoint, and then trains with
    gradient-accumulation-aware step counting, periodic logging, evaluation,
    model saving and rolling ``model_last`` checkpoints. It returns once
    ``global_step`` reaches ``args.max_steps``.

    Args:
        args: Parsed run configuration. Attributes read here include
            ``seed``, ``device``, the ``*_degree`` parallelism sizes,
            ``output_dir``, ``model_type``, ``model_name_or_path``,
            ``binary_head``, ``use_amp``, ``accumulate_steps``, the
            optimizer/lr hyper-parameters, and the ``logging_freq`` /
            ``eval_freq`` / ``save_steps`` / ``checkpoint_steps`` /
            ``max_steps`` schedule. ``args.decay_steps`` is overwritten with
            ``args.max_steps`` when it is ``None``.

    Raises:
        AssertionError: If ``args.device`` is invalid, if the product of the
            parallel degrees differs from the worker count, or if the
            checkpoint's global batch size differs from
            ``args.global_batch_size``.
    """
    # Initialize the paddle and paddle fleet execute environment
    paddle.enable_static()
    fleet.init(is_collective=True)

    # Create the random seed for the worker
    random.seed(args.seed)
    np.random.seed(args.seed)
    paddle.seed(args.seed)
    get_rng_state_tracker().add('global_seed', args.seed)
    get_rng_state_tracker().add('local_seed',
                                args.seed + fleet.worker_index() + 2021)

    assert args.device in [
        "cpu", "gpu", "xpu"
    ], "Invalid device! Available device should be cpu, gpu, or xpu."
    place = paddle.set_device(args.device)

    worker_num = fleet.worker_num()
    worker_index = fleet.worker_index()
    assert args.dp_degree * args.sharding_degree * args.mp_degree * args.pp_degree == worker_num, \
        "The product of degree num should be equal to worker_num."

    topo = Topology(device_rank=worker_index,
                    world_size=worker_num,
                    dp_degree=args.dp_degree,
                    pp_degree=args.pp_degree,
                    sharding_degree=args.sharding_degree,
                    mp_degree=args.mp_degree)

    logger.info("The topo of hybrid parallelism:\n{}".format(topo))

    dist_strategy = dist_optimizer(args, topo)

    # Create log write, train results show on last card of pipeline.
    # Every rank passes `log_writer` to run_evaluate below, so the name must
    # be bound on ranks that do not own a writer as well.
    log_writer = None
    if topo.is_last:
        log_writer_path = os.path.join(
            args.output_dir, "train_log",
            "{}_globalbsz_{}_amp_{}_recompute_{}_card_{}".format(
                args.model_name_or_path, args.global_batch_size, args.use_amp,
                args.use_recompute, worker_index).lower())
        # if os.path.exists(log_writer_path):
        #     shutil.rmtree(log_writer_path)
        log_writer = LogWriter(log_writer_path)

    # Define the input data in the static mode
    base_class, model_class, criterion_class, tokenizer_class = MODEL_CLASSES[
        args.model_type]
    pretrained_models_list = list(
        model_class.pretrained_init_configuration.keys())

    # load config in checkpoint
    global_step = 0
    # NOTE(review): consumed_samples is read from the checkpoint but not used
    # afterwards in this function.
    consumed_samples = 0
    checkpoint_dir = os.path.join(args.output_dir, "model_last")
    if os.path.exists(checkpoint_dir):
        if os.path.isfile(os.path.join(checkpoint_dir, "./config.yml")):
            with open(os.path.join(checkpoint_dir, "./config.yml"), "r") as f:
                # NOTE(review): the file is written by this function's own
                # checkpoint step; yaml.safe_load would be the safer choice if
                # the directory can ever hold untrusted content.
                step_config = yaml.load(f, Loader=yaml.FullLoader)
                assert step_config[
                    "global_batch_size"] == args.global_batch_size, "Please ensure checkpoint global batch size is the same. Folder: {}".format(
                        checkpoint_dir)
                consumed_samples = step_config["consumed_samples"]
                global_step = step_config["global_step"]

    data_file = get_train_data_file(args)
    main_program = paddle.static.default_main_program()
    startup_program = paddle.static.default_startup_program()
    with paddle.static.program_guard(main_program, startup_program):
        data_holders = create_data_holder(args)
        # 0. input_ids,
        # 1. segment_ids,
        # 2. input_mask,
        # 3. masked_lm_positions,
        # 4. masked_lm_labels,
        # 5. next_sentence_labels
        [
            input_ids, segment_ids, input_mask, masked_lm_positions,
            masked_lm_labels, next_sentence_labels
        ] = data_holders

        tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)

        train_data_loader, valid_data_loader, test_data_loader = create_pretrained_dataset(
            args,
            data_file,
            tokenizer,
            data_world_size=topo.data_info.size,
            data_world_rank=topo.data_info.rank,
            max_seq_len=args.max_seq_len,
            places=paddle.static.cuda_places(),
            data_holders=data_holders,
            current_step=global_step)
        # NOTE(review): second fleet.init call in this function; kept as in
        # the original - confirm whether re-initialisation is required here.
        fleet.init(is_collective=True)

        if args.model_name_or_path in pretrained_models_list:
            model_config = model_class.pretrained_init_configuration[
                args.model_name_or_path]
            # Round the vocabulary size up to the next multiple of 8.
            if model_config["vocab_size"] % 8 != 0:
                model_config["vocab_size"] += 8 - (model_config["vocab_size"] %
                                                   8)
            model_config["hidden_dropout_prob"] = args.hidden_dropout_prob
            model_config[
                "attention_probs_dropout_prob"] = args.attention_probs_dropout_prob
            model = model_class(base_class(**model_config))
        else:
            model, _ = model_class.from_pretrained(
                args.model_name_or_path,
                hidden_dropout_prob=args.hidden_dropout_prob,
                attention_probs_dropout_prob=args.attention_probs_dropout_prob,
            )

        # Create the model for the gpt pretrain
        prediction_scores, seq_relationship_score = model(
            input_ids=input_ids,
            token_type_ids=segment_ids,
            position_ids=None,
            attention_mask=input_mask,
            masked_positions=masked_lm_positions)

        criterion = criterion_class(with_nsp_loss=args.binary_head)
        if args.binary_head:
            lm_loss, sop_loss = criterion(prediction_scores,
                                          seq_relationship_score,
                                          masked_lm_labels,
                                          next_sentence_labels)
            loss = lm_loss + sop_loss
        else:
            loss = criterion(prediction_scores, seq_relationship_score,
                             masked_lm_labels)

        # Create the learning_rate scheduler and optimizer
        if args.decay_steps is None:
            args.decay_steps = args.max_steps

        # lr_scheduler = CosineAnnealingWithWarmupDecay(
        #     max_lr=args.max_lr,
        #     min_lr=args.min_lr,
        #     warmup_step=args.warmup_rate * args.max_steps,
        #     decay_step=args.decay_steps, last_epoch=global_step)

        lr_scheduler = LinearDecayWithWarmup(args.max_lr,
                                             args.max_steps,
                                             args.warmup_rate,
                                             last_epoch=global_step)

        clip = None
        if args.grad_clip > 0:
            clip = paddle.fluid.clip.GradientClipByGlobalNorm(
                clip_norm=args.grad_clip)

        # Weight decay is applied only to parameters whose structured name
        # contains neither "bias" nor "norm".
        decay_param = [
            p.name for n, p in model.named_parameters()
            if not any(nd in n for nd in ["bias", "norm"])
        ]
        logger.info("Using paddle.optimizer.AdamW.")
        optimizer = paddle.optimizer.AdamW(
            learning_rate=lr_scheduler,
            beta1=args.adam_beta1,
            beta2=args.adam_beta2,
            epsilon=args.adam_epsilon,
            grad_clip=clip,
            weight_decay=args.weight_decay,
            apply_decay_param_fun=lambda x: x in decay_param)
        # alias
        optimizer.apply_optimize = optimizer._apply_optimize

        # if args.use_recompute:
        #     dist_strategy.recompute = True
        #     dist_strategy.recompute_configs = {
        #         "checkpoints": model.bert.checkpoints
        #     }

        # Use the fleet api to compile the distributed optimizer
        optimizer = fleet.distributed_optimizer(optimizer,
                                                strategy=dist_strategy)

        optimizer.minimize(loss)
        logger.info(f'final strategy: {fleet._final_strategy()}')
        logger.info("The training meta optimizer is/are %s" %
                    fleet._get_applied_meta_list())

    # Dump the program descriptions for debugging. makedirs(exist_ok=True)
    # avoids the check-then-create race between concurrently starting workers
    # and also creates args.output_dir when it does not exist yet.
    program_desc_dir = os.path.join(args.output_dir, "program_desc")
    os.makedirs(program_desc_dir, exist_ok=True)

    with open(program_desc_dir + "/main_program.txt.%d" % worker_index,
              'w') as f:
        f.write(str(main_program))

    with open(program_desc_dir + "/startup_program.txt.%d" % worker_index,
              'w') as f:
        f.write(str(startup_program))

    # Define the Executor for running the static model
    exe = paddle.static.Executor(place)
    exe.run(startup_program)
    test_program = main_program.clone(for_test=True)

    if args.model_name_or_path not in pretrained_models_list:
        # model_name_or_path is treated as a directory: prefer static vars,
        # fall back to dygraph params.
        logger.info("Try to load checkpoint from %s " %
                    args.model_name_or_path)
        dygraph_path = os.path.join(args.model_name_or_path,
                                    "model_state.pdparams")
        static_path = os.path.join(args.model_name_or_path, "static_vars")

        flag_loaded = False
        if os.path.exists(static_path):
            if args.mp_degree > 1:
                logger.warning("MP should init with dygraph params")
            else:
                logger.info("Loading parameters from %s" % static_path)
                paddle.static.load(main_program, static_path, exe)
                flag_loaded = True

        if not flag_loaded and os.path.exists(dygraph_path):
            if args.sharding_degree > 1:
                logger.warning("Sharding should init with static vars")
            else:
                logger.info("Loading parameters from %s" % dygraph_path)
                init_static_with_params(
                    model, paddle.load(dygraph_path, return_numpy=True), topo,
                    main_program)
                flag_loaded = True

        if not flag_loaded:
            logger.error("No checkpoint load.")

    # load checkpoint vars
    if os.path.exists(checkpoint_dir):
        if os.path.isfile(os.path.join(checkpoint_dir, "./config.yml")):
            paddle.static.load(main_program,
                               os.path.join(checkpoint_dir, "static_vars"),
                               exe)

    # Variables to fetch on the last pipeline stage, grouped by how they are
    # reported in the log line.
    fetch_loss_vars = collections.OrderedDict()
    fetch_other_vars = collections.OrderedDict()
    fetch_loss_vars["loss"] = loss
    if args.binary_head:
        fetch_loss_vars["lm_loss"] = lm_loss
        fetch_loss_vars["sop_loss"] = sop_loss

    fetch_other_vars["learning_rate"] = main_program.global_block(
    ).vars["learning_rate_0"]

    additional_vars = collections.OrderedDict()
    if args.use_amp:
        for key in ["loss_scaling", "num_good_steps", "num_bad_steps"]:
            additional_vars[key] = main_program.global_block().vars[key + "_0"]

    tic_train = time.time()
    while True:
        fetchs = []
        fetchs_keys = []
        if topo.is_last:
            fetchs = list(fetch_loss_vars.values()) + list(
                fetch_other_vars.values()) + list(additional_vars.values())
            fetchs_keys = list(fetch_loss_vars.keys()) + list(
                fetch_other_vars.keys()) + list(additional_vars.keys())

        # Bug fix, if not call valid_data_loader, the enumerate will call valid_data_loader
        # many times. and start a new random dataloader.
        # NOTE(review): both names are rebound to the call result on every
        # pass of this outer loop, so a second pass calls the previous result
        # again - confirm the returned objects are themselves callable.
        valid_data_loader = valid_data_loader()
        test_data_loader = test_data_loader()

        for step, batch in enumerate(train_data_loader()):
            ret = exe.run(main_program,
                          feed=batch,
                          fetch_list=fetchs,
                          use_program_cache=True)
            # Skip for accumulate_steps in global step
            if (step + 1) % args.accumulate_steps != 0:
                continue
            global_step += 1
            # In the new 2.0 api, must call this function to change the learning_rate
            lr_scheduler.step()

            if global_step % args.logging_freq == 0:
                if topo.is_last:
                    res = collections.defaultdict(float)
                    for k, v in zip(fetchs_keys, ret):
                        res[k] = v[0]

                    speed = args.logging_freq / (time.time() - tic_train)

                    loss_info = ", ".join([
                        "{}: {:.6f}".format(k, res[k])
                        for k in fetch_loss_vars.keys()
                    ])

                    common_loginfo = "global step %d, %s, speed: %.2f steps/s, ips: %.2f seqs/s, learning rate: %.5e" % (
                        global_step, loss_info, speed,
                        speed * args.global_batch_size, res["learning_rate"])
                    additional_loginfo = ", ".join([
                        "{}: {}".format(k, res[k])
                        for k in additional_vars.keys()
                    ])
                    if additional_loginfo:
                        common_loginfo += ", " + additional_loginfo
                    logger.info(common_loginfo)
                    for k, v in res.items():
                        log_writer.add_scalar(k, v, global_step)

                tic_train = time.time()

            #if args.check_accuracy:
            #    if global_step >= args.max_steps:
            #        return
            #    else:
            #        continue

            if global_step % args.eval_freq == 0:
                # TODO, check the input data of validation
                eval_fetch = collections.OrderedDict()
                if topo.is_last:
                    eval_fetch["loss"] = loss
                    if args.binary_head:
                        eval_fetch["lm_loss"] = lm_loss
                        eval_fetch["sop_loss"] = sop_loss

                run_evaluate(valid_data_loader, exe, test_program,
                             args.eval_iters, log_writer, global_step, args,
                             topo.is_last, eval_fetch, "valid")
                tic_train = time.time()

            if global_step % args.save_steps == 0 or global_step >= args.max_steps:
                output_dir = os.path.join(args.output_dir,
                                          "model_%d" % global_step)
                logger.debug("saving models to {}".format(output_dir))
                save_persistables(exe, os.path.join(output_dir, "static_vars"),
                                  main_program)
                if global_step == args.save_steps:
                    # Drop any "topo" entry from the stored init config once,
                    # before the first save_pretrained call.
                    model.init_config["init_args"][0].init_config.pop(
                        "topo", None)
                model.save_pretrained(output_dir)
                tokenizer.save_pretrained(output_dir)
                tic_train = time.time()

            if global_step % args.checkpoint_steps == 0:
                output_dir = os.path.join(args.output_dir, "model_last")
                if worker_index == 0:
                    # Rotate model_last -> model_last_bak, then start a fresh
                    # model_last directory and record the resume metadata.
                    if not os.path.exists(output_dir):
                        os.mkdir(output_dir)
                    output_dir_bak = os.path.join(args.output_dir,
                                                  "model_last_bak")
                    if os.path.exists(output_dir):
                        if os.path.exists(output_dir_bak):
                            shutil.rmtree(output_dir_bak)
                        shutil.move(output_dir, output_dir_bak)
                        os.mkdir(output_dir)

                    step_config = {
                        "model_name": args.model_name_or_path,
                        "global_step": global_step,
                        "global_batch_size": args.global_batch_size,
                        "consumed_samples":
                        global_step * args.global_batch_size,
                    }

                    with open(os.path.join(output_dir, "config.yml"),
                              "w") as f:
                        yaml.dump(step_config,
                                  f,
                                  encoding='utf-8',
                                  allow_unicode=True)

                # Wait until worker 0 has prepared the directory.
                fleet.barrier_worker()

                logger.debug("saving models to {}".format(output_dir))
                if args.sharding_degree <= 1:
                    # Save on the first worker by default.
                    if worker_index == 0:
                        paddle.static.save(
                            main_program,
                            os.path.join(output_dir, "static_vars"))
                else:
                    # Use save_persistables in sharding, but more slower
                    save_persistables(exe,
                                      os.path.join(output_dir, "static_vars"),
                                      main_program)

            if global_step >= args.max_steps:
                eval_fetch = collections.OrderedDict()
                if topo.is_last:
                    eval_fetch["loss"] = loss
                    if args.binary_head:
                        eval_fetch["lm_loss"] = lm_loss
                        eval_fetch["sop_loss"] = sop_loss

                run_evaluate(test_data_loader, exe, test_program,
                             args.test_iters, log_writer, global_step, args,
                             topo.is_last, eval_fetch, "test")
                del train_data_loader
                return
def train(args):
    """Pretrain with sharding / pipeline / AMP / LAMB configured via fleet.

    The function seeds the RNGs, fills a ``fleet.DistributedStrategy`` from
    ``args``, builds the training program through ``create_model``, wraps the
    optimizer with fleet, dumps the program descriptions, optionally restores
    a checkpoint, and then runs the training loop. The last pipeline stage
    fetches losses and writes averaged metrics every ``args.log_steps``
    steps; persistables are saved every ``args.save_steps`` steps and once
    more when ``steps == args.num_train_steps``.

    Args:
        args: Parsed run configuration. Attributes read here include
            ``seed``, ``use_recompute``, ``use_sharding``, ``use_amp``,
            ``use_lamb``, ``use_sop``, ``use_hybrid_dp``, the ``num_dp`` /
            ``num_pp`` / ``num_sharding`` / ``num_mp`` sizes, ``global_bsz``,
            ``micro_bsz``, ``max_seq_len``, the learning-rate settings,
            ``output_dir``, ``init_checkpoint`` and the ``log_steps`` /
            ``eval_steps`` / ``save_steps`` / ``num_train_steps`` schedule.

    Raises:
        AssertionError: If ``global_bsz`` is not divisible by ``micro_bsz``
            or if sharding is not enabled.
    """
    log.info("pretraining start")

    # The selected GPU id is used for the device place and as the suffix of
    # the dumped program description files.
    gpu_id = int(os.environ.get('FLAGS_selected_gpus', 0))
    place = fluid.CUDAPlace(gpu_id)

    # set seed
    random.seed(args.seed)
    np.random.seed(args.seed)
    paddle.seed(args.seed)
    get_rng_state_tracker().add('global_seed', args.seed)
    get_rng_state_tracker().add('local_seed',
                                args.seed + fleet.worker_index() + 2021)

    # define execution strategy
    exec_strategy = fluid.ExecutionStrategy()
    exec_strategy.num_threads = 2
    exec_strategy.num_iteration_per_drop_scope = 1

    # define distribution strategy
    dist_strategy = fleet.DistributedStrategy()
    dist_strategy.execution_strategy = exec_strategy
    dist_strategy.nccl_comm_num = 3
    if args.use_recompute:
        log.info("using recompute.")
    dist_strategy.recompute = args.use_recompute
    dist_strategy.sharding = args.use_sharding
    dist_strategy.pipeline = args.num_pp > 1

    # define topology structure for dp/pp/mp
    topo = Topology(rank=fleet.worker_index(),
                    world_size=fleet.worker_num(),
                    dp=args.num_dp,
                    pp=args.num_pp,
                    sharding=args.num_sharding,
                    mp=args.num_mp)

    # Only the last pipeline stage fetches and reports metrics.
    is_last = topo.pp.rank == (topo.pp.size - 1)

    dp_sharding_rank = topo.dp.rank * topo.sharding.size + topo.sharding.rank
    dp_worldsize = topo.dp.size * topo.sharding.size
    bsz_per_dp = args.global_bsz // dp_worldsize

    micro_bsz = args.micro_bsz
    # The message reports the same value the condition checks
    # (args.global_bsz), so a failing check surfaces as AssertionError.
    assert args.global_bsz % micro_bsz == 0, f"cannot do gradient accumulate, globa_bsz: {args.global_bsz} micro_bsz: {micro_bsz}"
    acc_steps = bsz_per_dp // micro_bsz

    # sharding \ model parallel \ pipeline
    assert dist_strategy.sharding == True
    dist_strategy.sharding_configs = {
        "segment_broadcast_MB": 32,
        "sharding_degree": args.num_sharding,
        "mp_degree": args.num_mp,
        "pp_degree": args.num_pp,
        "dp_degree": args.num_dp,
        "optimize_offload": True,
    }
    dist_strategy.pipeline_configs = {
        "schedule_mode": "1F1B",
        "micro_batch_size": micro_bsz,
        "accumulate_steps": acc_steps,
    }
    log.info(
        f"using globa_bsz: {args.global_bsz} micro_bsz: {micro_bsz}, acc_steps: {acc_steps}"
    )

    dist_strategy.amp = args.use_amp
    dist_strategy.amp_configs = {
        "custom_white_list": ['softmax', 'layer_norm', 'gelu'],
        "init_loss_scaling": 32768,
        "decr_every_n_nan_or_inf": 2,
        "incr_every_n_steps": 1000,
        "incr_ratio": 2.0,
        "use_dynamic_loss_scaling": True,
        "decr_ratio": 0.5,
        "use_pure_fp16": False,
        "use_fp16_guard": False,
    }

    dist_strategy.lamb = args.use_lamb
    dist_strategy.lamb_configs = {
        'lamb_weight_decay':
        0.01,
        'exclude_from_weight_decay':
        ['layer_norm_bias', 'layer_norm_scale', '.b_0']
    }

    train_program = fluid.Program()
    startup_program = fluid.Program()
    with fluid.program_guard(train_program, startup_program):
        with fluid.unique_name.guard():
            graph_vars = create_model(args, 'train', micro_bsz,
                                      dp_sharding_rank, dp_worldsize, topo)
            data_loader = graph_vars['data_loader']
            for op in train_program.global_block().ops:
                if op.type == 'fill_constant':
                    op._set_attr(
                        'op_device', "gpu:0"
                    )  # XXX: hack: https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/fluid/layers/tensor.py#L1376

            if args.use_recompute:
                dist_strategy.recompute_configs = {
                    "checkpoints": graph_vars['checkpoints'],
                    # "enable_offload": args.use_offload,
                    # "checkpoint_shape": [micro_bsz, args.max_seq_len, 4096],
                }

            log.debug("base lr: {}".format(args.learning_rate))
            scheduled_lr = linear_warmup_decay(
                learning_rate=args.learning_rate,
                warmup_steps=args.warmup_steps,
                num_train_steps=args.num_train_steps)

            clip_norm_thres = 1.0
            if paddlenlp.ops.optimizer._jit_compile():
                optimizer = paddlenlp.ops.optimizer.AdamwOptimizer(
                    learning_rate=scheduled_lr,
                    grad_clip=fluid.clip.GradientClipByGlobalNorm(
                        clip_norm=clip_norm_thres),
                    weight_decay=args.weight_decay,
                    apply_decay_param_fun=apply_weight_decay_fun)
            else:
                optimizer = fluid.optimizer.Adam(
                    learning_rate=scheduled_lr,
                    grad_clip=fluid.clip.GradientClipByGlobalNorm(
                        clip_norm=clip_norm_thres),
                    #multi_precision=True,
                    #weight_decay=args.weight_decay, # merge this pr to use weight_decay: https://github.com/PaddlePaddle/Paddle/pull/29248
                    #exclude_from_weight_decay_fn=exclude_from_weight_decay
                )

            optimizer = fleet.distributed_optimizer(optimizer, dist_strategy)
            log.info(f"using dist strategy: {dist_strategy}")

            optimizer.minimize(graph_vars['total_loss'])

            final_strategy = fleet._final_strategy()
            applied_meta_list = fleet._get_applied_meta_list()
            log.info("final strategy: {}".format(final_strategy))
            log.info("applied_meta_list: {}".format(applied_meta_list))

    # Dump the program descriptions for debugging. makedirs(exist_ok=True)
    # avoids the check-then-create race between concurrently starting workers
    # and also creates args.output_dir when it does not exist yet.
    program_desc_dir = os.path.join(args.output_dir, "program_desc")
    os.makedirs(program_desc_dir, exist_ok=True)

    with open(program_desc_dir + "/main_program.txt.%d" % gpu_id, 'w') as f:
        f.write(str(train_program))

    with open(program_desc_dir + "/startup_program.txt.%d" % gpu_id,
              'w') as f:
        f.write(str(startup_program))

    exe = fluid.Executor(place)
    exe.run(startup_program)
    optimizer.amp_init(place)

    #save_path = os.path.join(args.output_dir, 'step_0')
    #log.debug("saving models to {}".format(save_path))
    #save_persistables(exe, save_path, train_program)

    if args.init_checkpoint and args.init_checkpoint != "":
        log.info(' ')
        log.info(
            '############################WARNING############################')
        log.info(
            '####### using ini_checkpoint, not init_pretraining_params ####')
        log.info(
            '## meaning hyper param e.g. lr will inherit from checkpoint ##')
        log.info(
            '###############################################################')
        init_checkpoint(exe, args.init_checkpoint, train_program)
        log.info(' ')

    output_dir = args.output_dir
    total_time = 0
    cost_vals, lm_losses, sop_accs = [], [], []
    # Only referenced by the disabled step-skipping logic in the loop below.
    global_steps = args.global_steps + 1
    steps = 0
    log_path = 'train_log/node-%d' % fleet.worker_index()
    start_time = time.time()

    with LogWriter(os.path.join(args.output_dir, log_path)) as swriter:
        data_loader.start()
        # NOTE(review): this loop has no break/return; it presumably ends
        # when exe.run raises once the data loader is exhausted - confirm.
        while True:
            #if steps < global_steps:
            #    steps += 1
            #    continue
            if not is_last:
                fetch_list = []
            else:
                # The order here must match the tuple unpacking of `ret` below.
                fetch_list = [
                    graph_vars['total_loss'], graph_vars['mean_mask_lm_loss'],
                    scheduled_lr
                ]
                if args.use_sop:
                    fetch_list.extend(
                        [graph_vars['sop_acc'], graph_vars['sop_loss']])
                if args.use_amp:
                    loss_scaling = train_program.global_block(
                    ).vars['loss_scaling_0']
                    fetch_list.append(loss_scaling)

            ret = exe.run(train_program, fetch_list=fetch_list
                          )  # run one mini-batch(=acc_steps micro-batch)
            #use_program_cache=True)

            steps += 1

            if is_last:
                if args.use_sop and args.use_amp:
                    cost_val, lm_loss, lr, sop_acc, sop_loss, loss_scaling_0 = ret
                elif args.use_sop:
                    cost_val, lm_loss, lr, sop_acc, sop_loss = ret
                elif args.use_amp:
                    cost_val, lm_loss, lr, loss_scaling_0 = ret
                else:
                    cost_val, lm_loss, lr = ret
                cost_vals.append(cost_val[0])
                lm_losses.append(lm_loss[0])
                if args.use_sop:
                    sop_accs.append(sop_acc[0])

                if steps > 0 and (steps % args.log_steps) == 0:
                    end_time = time.time()
                    total_time = end_time - start_time
                    # Report the mean over the steps since the last log line.
                    cost_val = np.mean(cost_vals)
                    lm_loss = np.mean(lm_losses)
                    swriter.add_scalar('loss/total_loss', cost_val, steps)
                    swriter.add_scalar('loss/mlm_loss', lm_loss, steps)
                    swriter.add_scalar('lr/scheduled_lr', lr[0], steps)

                    if args.use_sop:
                        sop_acc = np.mean(sop_accs)
                        swriter.add_scalar('loss/sop_loss', sop_loss, steps)
                        swriter.add_scalar('train/sop_acc', sop_acc, steps)
                    else:
                        sop_acc = 0.0

                    if args.use_amp:
                        swriter.add_scalar('lr/loss_scaling',
                                           loss_scaling_0[0], steps)
                    else:
                        loss_scaling_0 = [0.0]

                    log.info(
                        "worker_index: %d, step: %d, cost: %f, "
                        "mlm loss: %f, sentence order acc: %f, "
                        "speed: %f steps/s, "
                        "speed: %f samples/s, "
                        "speed: %f tokens/s, "
                        "learning rate: %.3e, loss_scalings: %f" %
                        (fleet.worker_index(), steps, cost_val, lm_loss,
                         sop_acc, args.log_steps / total_time,
                         args.log_steps * args.global_bsz / total_time,
                         args.log_steps * args.global_bsz * args.max_seq_len /
                         total_time, lr[0], loss_scaling_0[0]))

                    cost_vals, lm_losses, sop_accs = [], [], []
                    start_time = time.time()

            # TODO: add evaluation
            if steps > 0 and args.eval_steps > 0 and steps % args.eval_steps == 0:
                pass

            if steps > 0 and args.save_steps > 0 and steps % args.save_steps == 0:
                if args.use_hybrid_dp and fleet.worker_index() > 8:
                    continue
                save_path = os.path.join(output_dir, 'step_' + str(steps))
                log.debug("saving models to {}".format(save_path))
                save_persistables(exe, save_path, train_program)

            if steps == args.num_train_steps:
                if args.use_hybrid_dp and fleet.worker_index() > 8:
                    continue
                save_path = os.path.join(output_dir,
                                         'final_step_' + str(steps))
                save_persistables(exe, save_path, train_program)
                log.debug("saving final models to {}".format(save_path))
                log.debug("end of training, total steps: {}".format(steps))