def train(data_train, data_eval, model, nsp_loss, mlm_loss, vocab_size, ctx):
    """Run the Horovod BERT pre-training loop on one device per worker.

    Parameters
    ----------
    data_train : iterable
        Iterable of data loaders; it is re-iterated from the start until
        ``args.num_steps`` optimizer steps have been taken.
    data_eval :
        Evaluation data spec passed to ``get_pretrain_data_npz`` at every
        checkpoint; evaluation is skipped when it is falsy.
    model :
        The model to train; its parameters are broadcast from rank 0 first.
    nsp_loss, mlm_loss :
        Next-sentence-prediction and masked-LM loss blocks.
    vocab_size : int
        Vocabulary size, forwarded to ``forward`` and ``evaluate``.
    ctx :
        The single device this worker trains on.

    Notes
    -----
    Remaining configuration is read from module-level globals (``args``,
    ``num_workers``, ``rank``, ``local_rank``, ``is_master_node``).
    """
    # Make every worker start from rank 0's parameters.
    hvd.broadcast_parameters(model.collect_params(), root_rank=0)

    mlm_metric = nlp.metric.MaskedAccuracy()
    nsp_metric = nlp.metric.MaskedAccuracy()
    mlm_metric.reset()
    nsp_metric.reset()

    logging.debug('Creating distributed trainer...')
    lr = args.lr
    optim_params = {'learning_rate': lr, 'epsilon': 1e-6, 'wd': 0.01}
    if args.dtype == 'float16':
        optim_params['multi_precision'] = True

    dynamic_loss_scale = args.dtype == 'float16'
    if dynamic_loss_scale:
        loss_scale_param = {'scale_window': 2000 / num_workers}
    else:
        loss_scale_param = None
    trainer = hvd.DistributedTrainer(model.collect_params(), 'bertadam', optim_params)
    fp16_trainer = FP16Trainer(trainer, dynamic_loss_scale=dynamic_loss_scale,
                               loss_scaler_params=loss_scale_param)

    # Resume optimizer state when restarting from a checkpointed step.
    if args.start_step:
        state_path = os.path.join(args.ckpt_dir,
                                  '%07d.states.%02d' % (args.start_step, local_rank))
        logging.info('Loading trainer state from %s', state_path)
        nlp.utils.load_states(trainer, state_path)

    accumulate = args.accumulate
    num_train_steps = args.num_steps
    warmup_ratio = args.warmup_ratio
    num_warmup_steps = int(num_train_steps * warmup_ratio)
    params = [p for p in model.collect_params().values() if p.grad_req != 'null']
    param_dict = model.collect_params()

    # Do not apply weight decay on LayerNorm and bias terms
    for _, v in model.collect_params('.*beta|.*gamma|.*bias').items():
        v.wd_mult = 0.0
    # With gradient accumulation, gradients are summed over micro-batches and
    # cleared explicitly at the start of each optimizer step (see below).
    if accumulate > 1:
        for p in params:
            p.grad_req = 'add'

    train_begin_time = time.time()
    begin_time = time.time()
    running_mlm_loss, running_nsp_loss = 0, 0
    running_num_tks = 0
    batch_num = 0
    step_num = args.start_step

    logging.debug('Training started')
    while step_num < num_train_steps:
        for _, dataloader in enumerate(data_train):
            if step_num >= num_train_steps:
                break

            # create dummy data loader if needed
            if args.dummy_data_len:
                target_shape = (args.batch_size, args.dummy_data_len)
                dataloader = get_dummy_dataloader(dataloader, target_shape)

            for _, data_batch in enumerate(dataloader):
                if step_num >= num_train_steps:
                    break
                # First micro-batch of an optimizer step: advance the step
                # counter, reset accumulated grads and update the LR.
                if batch_num % accumulate == 0:
                    step_num += 1
                    # if accumulate > 1, grad_req is set to 'add', and zero_grad is required
                    if accumulate > 1:
                        param_dict.zero_grad()
                    # update learning rate: linear warm-up, then linear decay
                    if step_num <= num_warmup_steps:
                        new_lr = lr * step_num / num_warmup_steps
                    else:
                        offset = lr * step_num / num_train_steps
                        new_lr = lr - offset
                    trainer.set_learning_rate(new_lr)
                    if args.profile:
                        profile(step_num, 10, 14, profile_name=args.profile + str(rank))

                # load data
                if args.use_avg_len:
                    data_list = [[seq.as_in_context(context) for seq in shard]
                                 for context, shard in zip([ctx], data_batch)]
                else:
                    data_list = list(split_and_load(data_batch, [ctx]))
                data = data_list[0]

                # forward
                with mx.autograd.record():
                    (ls, ns_label, classified, masked_id, decoded,
                     masked_weight, ls1, ls2, valid_len) = forward(data, model, mlm_loss,
                                                                 nsp_loss, vocab_size,
                                                                 args.dtype)
                    # Scale so the accumulated gradient is an average over micro-batches.
                    ls = ls / accumulate
                # backward
                if args.dtype == 'float16':
                    fp16_trainer.backward(ls)
                else:
                    ls.backward()

                running_mlm_loss += ls1.as_in_context(mx.cpu())
                running_nsp_loss += ls2.as_in_context(mx.cpu())
                running_num_tks += valid_len.sum().as_in_context(mx.cpu())

                # update
                if (batch_num + 1) % accumulate == 0:
                    # step() performs 3 things:
                    # 1. allreduce gradients from all workers
                    # 2. checking the global_norm of gradients and clip them if necessary
                    # 3. averaging the gradients and apply updates
                    fp16_trainer.step(1, max_norm=1 * num_workers)

                nsp_metric.update([ns_label], [classified])
                mlm_metric.update([masked_id], [decoded], [masked_weight])

                # logging
                if (step_num + 1) % (args.log_interval) == 0 and (batch_num + 1) % accumulate == 0:
                    log(begin_time, running_num_tks, running_mlm_loss / accumulate,
                        running_nsp_loss / accumulate, step_num, mlm_metric, nsp_metric,
                        trainer, args.log_interval)
                    begin_time = time.time()
                    running_mlm_loss = running_nsp_loss = running_num_tks = 0
                    mlm_metric.reset_local()
                    nsp_metric.reset_local()

                # saving checkpoints
                if (step_num + 1) % args.ckpt_interval == 0 and (batch_num + 1) % accumulate == 0:
                    if is_master_node:
                        save_states(step_num, trainer, args.ckpt_dir, local_rank)
                        if local_rank == 0:
                            save_parameters(step_num, model, args.ckpt_dir)
                    if data_eval:
                        # eval data is always based on a fixed npz file.
                        dataset_eval = get_pretrain_data_npz(data_eval, args.batch_size_eval, 1,
                                                             False, False, 1)
                        # Use the vocab_size argument rather than a module-level
                        # ``vocab`` global so the function is self-contained.
                        evaluate(dataset_eval, model, nsp_loss, mlm_loss, vocab_size, [ctx],
                                 args.log_interval, args.dtype)
                batch_num += 1

    # Final checkpoint after the step budget is exhausted.
    if is_master_node:
        save_states(step_num, trainer, args.ckpt_dir, local_rank)
        if local_rank == 0:
            save_parameters(step_num, model, args.ckpt_dir)
    mx.nd.waitall()
    train_end_time = time.time()
    logging.info('Train cost={:.1f}s'.format(train_end_time - train_begin_time))
def train(data_train, model, nsp_loss, mlm_loss, vocab_size, ctx, store):
    """Run the KVStore-based BERT pre-training loop over several local devices.

    Parameters
    ----------
    data_train : iterable
        Iterable of data loaders; it is re-iterated from the start until
        ``args.num_steps`` optimizer steps have been taken.
    model :
        The model to train.
    nsp_loss, mlm_loss :
        Next-sentence-prediction and masked-LM loss blocks.
    vocab_size : int
        Vocabulary size, forwarded to ``ParallelBERT``.
    ctx : list
        Local devices; each batch is sharded across them.
    store :
        KVStore used by the Gluon trainer; ``store.rank == 0`` saves checkpoints.

    Notes
    -----
    Remaining configuration is read from the module-level ``args`` global.
    """
    mlm_metric = nlp.metric.MaskedAccuracy()
    nsp_metric = nlp.metric.MaskedAccuracy()
    mlm_metric.reset()
    nsp_metric.reset()

    lr = args.lr
    optim_params = {'learning_rate': lr, 'epsilon': 1e-6, 'wd': 0.01}
    if args.dtype == 'float16':
        optim_params['multi_precision'] = True

    trainer = mx.gluon.Trainer(model.collect_params(), 'bertadam', optim_params,
                               update_on_kvstore=False, kvstore=store)
    dynamic_loss_scale = args.dtype == 'float16'
    fp16_trainer = FP16Trainer(trainer, dynamic_loss_scale=dynamic_loss_scale)

    # Resume optimizer state when restarting from a checkpointed step.
    if args.start_step:
        state_path = os.path.join(args.ckpt_dir, '%07d.states.%02d' % (args.start_step, 0))
        logging.info('Loading trainer state from %s', state_path)
        nlp.utils.load_states(trainer, state_path)

    accumulate = args.accumulate
    num_train_steps = args.num_steps
    warmup_ratio = args.warmup_ratio
    num_warmup_steps = int(num_train_steps * warmup_ratio)
    params = [p for p in model.collect_params().values() if p.grad_req != 'null']
    param_dict = model.collect_params()

    # Do not apply weight decay on LayerNorm and bias terms
    for _, v in model.collect_params('.*beta|.*gamma|.*bias').items():
        v.wd_mult = 0.0
    # With gradient accumulation, gradients are summed over micro-batches and
    # cleared explicitly at the start of each optimizer step (see below).
    if accumulate > 1:
        for p in params:
            p.grad_req = 'add'

    train_begin_time = time.time()
    begin_time = time.time()
    running_mlm_loss = running_nsp_loss = running_num_tks = 0
    batch_num = 0
    step_num = args.start_step

    parallel_model = ParallelBERT(model, mlm_loss, nsp_loss, vocab_size,
                                  store.num_workers * accumulate, trainer=fp16_trainer)
    num_ctxes = len(ctx)
    # With a single device, Parallel runs the model inline (0 worker threads).
    parallel = nlp.utils.Parallel(num_ctxes if num_ctxes > 1 else 0, parallel_model)

    while step_num < num_train_steps:
        for _, dataloader in enumerate(data_train):
            if step_num >= num_train_steps:
                break

            # create dummy data loader if needed
            if args.dummy_data_len:
                target_shape = (args.batch_size * num_ctxes, args.dummy_data_len)
                dataloader = get_dummy_dataloader(dataloader, target_shape)

            for _, data_batch in enumerate(dataloader):
                if step_num >= num_train_steps:
                    break
                # Skip batches too small to give every device a shard. This must
                # run before the step bookkeeping below: otherwise a skipped batch
                # would advance step_num (and the LR schedule) without any
                # optimizer update, because batch_num is not incremented.
                if not args.use_avg_len and data_batch[0].shape[0] < len(ctx):
                    continue
                # First micro-batch of an optimizer step: advance the step
                # counter, reset accumulated grads and update the LR.
                if batch_num % accumulate == 0:
                    step_num += 1
                    # if accumulate > 1, grad_req is set to 'add', and zero_grad is required
                    if accumulate > 1:
                        param_dict.zero_grad()
                    # update learning rate: linear warm-up, then linear decay
                    if step_num <= num_warmup_steps:
                        new_lr = lr * step_num / num_warmup_steps
                    else:
                        offset = lr * step_num / num_train_steps
                        new_lr = lr - offset
                    trainer.set_learning_rate(new_lr)
                    if args.profile:
                        profile(step_num, 10, 12, profile_name=args.profile)
                if args.use_avg_len:
                    data_list = [[seq.as_in_context(context) for seq in shard]
                                 for context, shard in zip(ctx, data_batch)]
                else:
                    # Materialize so the number of shards is known.
                    data_list = list(split_and_load(data_batch, ctx))

                ns_label_list, ns_pred_list = [], []
                mask_label_list, mask_pred_list, mask_weight_list = [], [], []

                # parallel forward / backward; collect exactly as many results
                # as shards were submitted, otherwise get() would block forever
                # when there are fewer shards than devices.
                num_data = len(data_list)
                for data in data_list:
                    parallel.put(data)
                for _ in range(num_data):
                    (_, next_sentence_label, classified, masked_id,
                     decoded, masked_weight, ls1, ls2, valid_length) = parallel.get()
                    ns_label_list.append(next_sentence_label)
                    ns_pred_list.append(classified)
                    mask_label_list.append(masked_id)
                    mask_pred_list.append(decoded)
                    mask_weight_list.append(masked_weight)
                    running_mlm_loss += ls1.as_in_context(mx.cpu()) / num_ctxes
                    running_nsp_loss += ls2.as_in_context(mx.cpu()) / num_ctxes
                    running_num_tks += valid_length.sum().as_in_context(mx.cpu())

                # update
                if (batch_num + 1) % accumulate == 0:
                    fp16_trainer.step(1, max_norm=1)
                nsp_metric.update(ns_label_list, ns_pred_list)
                mlm_metric.update(mask_label_list, mask_pred_list, mask_weight_list)

                # logging
                if (step_num + 1) % (args.log_interval) == 0 and (batch_num + 1) % accumulate == 0:
                    log(begin_time, running_num_tks, running_mlm_loss / accumulate,
                        running_nsp_loss / accumulate, step_num, mlm_metric, nsp_metric,
                        trainer, args.log_interval)
                    begin_time = time.time()
                    running_mlm_loss = running_nsp_loss = running_num_tks = 0
                    mlm_metric.reset_local()
                    nsp_metric.reset_local()

                # saving checkpoints
                if (step_num + 1) % args.ckpt_interval == 0 \
                        and (batch_num + 1) % accumulate == 0 and store.rank == 0:
                    save_states(step_num, trainer, args.ckpt_dir)
                    save_parameters(step_num, model, args.ckpt_dir)
                batch_num += 1

    # Final checkpoint after the step budget is exhausted.
    if store.rank == 0:
        save_states(step_num, trainer, args.ckpt_dir)
        save_parameters(step_num, model, args.ckpt_dir)
    mx.nd.waitall()
    train_end_time = time.time()
    logging.info('Train cost={:.1f}s'.format(train_end_time - train_begin_time))
def train(data_train, data_eval, model):
    """Run the BERT pre-training loop (Horovod or plain Gluon backend).

    Builds a trainer (Horovod ``DistributedTrainer`` when ``backend == 'horovod'``,
    otherwise a Gluon ``Trainer`` with ``update_on_kvstore=False``) wrapped in
    ``FP16Trainer``, optionally restores trainer state for ``args.start_step``,
    and iterates over ``data_train`` until ``args.num_steps`` optimizer steps
    have been taken.

    Parameters
    ----------
    data_train : iterable
        Iterable of batches; it is re-iterated from the start whenever it is
        exhausted before the step budget is reached.
    data_eval :
        Evaluation data spec passed to ``get_pretrain_data_npz`` every
        ``args.eval_interval`` steps; evaluation is skipped when falsy.
    model :
        Must expose ``.bert``; only ``model.bert``'s parameters are broadcast,
        optimized and saved.

    Notes
    -----
    Remaining configuration is read from module-level globals (``args``,
    ``backend``, ``num_workers``, ``ctxs``, ``rank``, ``local_rank``,
    ``is_master_node``, ``vocab``, ``batch_size_eval``).
    """
    # backend specific implementation
    param_dict = model.bert.collect_params()
    if backend == 'horovod':
        # Make every worker start from rank 0's parameters.
        hvd.broadcast_parameters(param_dict, root_rank=0)

    mlm_metric = nlp.metric.MaskedAccuracy()
    nsp_metric = nlp.metric.MaskedAccuracy()
    mlm_metric.reset()
    nsp_metric.reset()

    logging.info('Creating distributed trainer...')
    lr = args.lr
    optim_params = {'learning_rate': lr, 'epsilon': 1e-6, 'wd': 0.01}
    if args.dtype == 'float16':
        optim_params['multi_precision'] = True

    dynamic_loss_scale = args.dtype == 'float16'
    if dynamic_loss_scale:
        loss_scale_param = {
            'scale_window': 2000 / num_workers,
            'init_scale': 2**10
        }
    else:
        loss_scale_param = None

    # backend specific implementation
    if backend == 'horovod':
        trainer = hvd.DistributedTrainer(param_dict, args.optimizer, optim_params)
    else:
        trainer = mx.gluon.Trainer(param_dict, args.optimizer, optim_params,
                                   update_on_kvstore=False)
    fp16_trainer = FP16Trainer(trainer, dynamic_loss_scale=dynamic_loss_scale,
                               loss_scaler_params=loss_scale_param)

    # Resume optimizer state when restarting from a checkpointed step.
    if args.start_step:
        state_path = os.path.join(
            args.ckpt_dir, '%07d.states.%02d' % (args.start_step, local_rank))
        logging.info('Loading trainer state from %s', state_path)
        nlp.utils.load_states(trainer, state_path)

    accumulate = args.accumulate
    num_train_steps = args.num_steps
    warmup_ratio = args.warmup_ratio
    num_warmup_steps = int(num_train_steps * warmup_ratio)
    params = [p for p in param_dict.values() if p.grad_req != 'null']

    # Do not apply weight decay on LayerNorm and bias terms
    for _, v in model.collect_params('.*beta|.*gamma|.*bias').items():
        v.wd_mult = 0.0
    # With gradient accumulation, gradients are summed over micro-batches and
    # cleared explicitly right after each optimizer step (see below).
    if accumulate > 1:
        for p in params:
            p.grad_req = 'add'

    train_begin_time = time.time()
    begin_time = time.time()
    running_mlm_loss, running_nsp_loss = 0, 0
    running_num_tks = 0
    batch_num = 0
    step_num = args.start_step

    # In phase 2 the step counter (and thus the LR schedule, logging and
    # checkpoint cadence) is measured relative to the end of phase 1.
    if args.phase2:
        step_num -= args.phase1_num_steps

    logging.info('Training started')

    # Wrap the model so forward/backward can be dispatched to every device;
    # with a single device, Parallel runs it inline (0 worker threads).
    parallel_model = DataParallelBERT(model, trainer=fp16_trainer)
    num_ctxes = len(ctxs)
    parallel = nlp.utils.Parallel(num_ctxes if num_ctxes > 1 else 0, parallel_model)

    while step_num < num_train_steps:

        data_train_iter = iter(data_train)
        end_of_batch = False
        next_data_batch = next(data_train_iter)
        while not end_of_batch:
            data_batch = next_data_batch
            if step_num >= num_train_steps:
                break
            # First micro-batch of an optimizer step: advance the counter and LR.
            if batch_num % accumulate == 0:
                step_num += 1
                # update learning rate: linear warm-up to lr, then linear decay
                # from lr (at the end of warm-up) down to 0 at num_train_steps.
                if step_num <= num_warmup_steps:
                    new_lr = lr * step_num / num_warmup_steps
                else:
                    offset = (num_train_steps - step_num) / (num_train_steps - num_warmup_steps)
                    new_lr = lr * max(offset, 0)
                trainer.set_learning_rate(new_lr)
                if args.profile:
                    profile(step_num, 10, 14, profile_name=args.profile + str(rank))

            # load data
            data_list = list(split_and_load(data_batch, ctxs))

            ns_label_list, ns_pred_list = [], []
            mask_label_list, mask_pred_list, mask_weight_list = [], [], []

            # Submit one shard per device, then collect one result per shard.
            num_data = len(data_list)
            for i in range(num_data):
                parallel.put(data_list[i])
            for _ in range(num_data):
                (next_sentence_label, classified, masked_id,
                 decoded, masked_weight, ls1, ls2, valid_length) = parallel.get()
                ns_label_list.append(next_sentence_label)
                ns_pred_list.append(classified)
                mask_label_list.append(masked_id)
                mask_pred_list.append(decoded)
                mask_weight_list.append(masked_weight)
                running_mlm_loss += ls1.as_in_context(mx.cpu()) / len(ctxs)
                running_nsp_loss += ls2.as_in_context(mx.cpu()) / len(ctxs)
                running_num_tks += valid_length.sum().as_in_context(mx.cpu())
            # pre fetch next batch; the loader being exhausted ends this pass
            # over data_train (the outer loop restarts it if steps remain).
            try:
                next_data_batch = next(data_train_iter)
            except StopIteration:
                end_of_batch = True

            # update
            if (batch_num + 1) % accumulate == 0:
                fp16_trainer.step(1, max_norm=1.0 * num_workers)
                # Clear the summed gradients for the next accumulation window.
                if accumulate > 1:
                    param_dict.zero_grad()
            # update metrics
            if args.no_compute_acc:
                # NOTE(review): only the first device's predictions are
                # synchronized here; confirm this is sufficient.
                mask_pred_list[0].wait_to_read()
            else:
                nsp_metric.update(ns_label_list, ns_pred_list)
                mlm_metric.update(mask_label_list, mask_pred_list, mask_weight_list)

            # logging (only on the last micro-batch of a step)
            if step_num % (args.log_interval) == 0 and (batch_num + 1) % accumulate == 0:
                if args.no_compute_acc:
                    log_noacc(begin_time, running_num_tks, running_mlm_loss / accumulate,
                              running_nsp_loss / accumulate, step_num,
                              trainer, args.log_interval)
                else:
                    log(begin_time, running_num_tks, running_mlm_loss / accumulate,
                        running_nsp_loss / accumulate, step_num, mlm_metric,
                        nsp_metric, trainer, args.log_interval)
                    mlm_metric.reset_local()
                    nsp_metric.reset_local()
                begin_time = time.time()
                running_mlm_loss = running_nsp_loss = running_num_tks = 0

            # saving checkpoints
            if step_num % args.ckpt_interval == 0 and (batch_num + 1) % accumulate == 0:
                if is_master_node:
                    save_states(step_num, trainer, args.ckpt_dir, local_rank)
                    if local_rank == 0:
                        save_parameters(step_num, model.bert, args.ckpt_dir)
            if step_num % args.eval_interval == 0 and data_eval \
                    and (batch_num + 1) % accumulate == 0:
                # eval data is always based on a fixed npz file.
                dataset_eval = get_pretrain_data_npz(data_eval, batch_size_eval,
                                                     1, False, 1, vocab)
                evaluate(dataset_eval, model, ctxs, args.log_interval, args.dtype)

            batch_num += 1

    # Final checkpoint after the step budget is exhausted.
    if is_master_node:
        save_states(step_num, trainer, args.ckpt_dir, local_rank)
        if local_rank == 0:
            save_parameters(step_num, model.bert, args.ckpt_dir)
    mx.nd.waitall()
    train_end_time = time.time()
    logging.info('Train cost={:.1f}s'.format(train_end_time - train_begin_time))
def train(data_train, data_eval, model):
    """Run the BERT pre-training loop (Horovod, BytePS or plain Gluon backend).

    The MLM loss and gradients are normalised by the global number of masked
    tokens in each optimizer step (``local_num_masks`` is summed across
    workers before ``fp16_trainer.step``).

    Parameters
    ----------
    data_train : iterable
        Iterable of batches; it is re-iterated from the start whenever it is
        exhausted before ``args.num_steps`` steps have been taken.
    data_eval :
        Evaluation data spec passed to ``get_pretrain_data_npz`` every
        ``args.eval_interval`` steps; evaluation is skipped when falsy.
    model :
        Must expose ``.bert``; only ``model.bert``'s parameters are broadcast
        and optimized.

    Notes
    -----
    Remaining configuration is read from module-level globals (``args``,
    ``backend``, ``num_workers``, ``ctxs``, ``rank``, ``local_rank``,
    ``vocab``, ``batch_size``, ``batch_size_eval``, ``early_stop``).
    """
    # backend specific implementation
    param_dict = model.bert.collect_params()
    if backend == 'horovod':
        # Make every worker start from rank 0's parameters.
        hvd.broadcast_parameters(param_dict, root_rank=0)

    mlm_metric = nlp.metric.MaskedAccuracy()
    nsp_metric = nlp.metric.MaskedAccuracy()
    mlm_metric.reset()
    nsp_metric.reset()

    logging.debug('Creating distributed trainer...')
    lr = args.lr
    optim_params = {'learning_rate': lr, 'epsilon': 1e-6, 'wd': 0.01}
    if args.dtype == 'float16':
        optim_params['multi_precision'] = True
    if args.optimizer == 'lamb':
        optim_params['bias_correction'] = True

    dynamic_loss_scale = args.dtype == 'float16'
    if dynamic_loss_scale:
        loss_scale_param = {'scale_window': 2000 / num_workers, 'init_scale': 1}
    else:
        loss_scale_param = None

    # backend specific implementation
    if backend == 'horovod':
        trainer = hvd.DistributedTrainer(param_dict, args.optimizer, optim_params)
    elif backend == 'byteps':
        trainer = bps.DistributedTrainer(param_dict, args.optimizer, optim_params)
    else:
        trainer = mx.gluon.Trainer(param_dict, args.optimizer, optim_params,
                                   update_on_kvstore=False)
    fp16_trainer = FP16Trainer(trainer, dynamic_loss_scale=dynamic_loss_scale,
                               loss_scaler_params=loss_scale_param)

    # Resume optimizer state when restarting from a checkpointed step.
    if args.start_step:
        state_path = os.path.join(args.ckpt_dir,
                                  '%07d.states.%02d' % (args.start_step, local_rank))
        logging.info('Loading trainer state from %s', state_path)
        nlp.utils.load_states(trainer, state_path)

    accumulate = args.accumulate
    num_train_steps = args.num_steps
    warmup_ratio = args.warmup_ratio
    num_warmup_steps = int(num_train_steps * warmup_ratio)
    params = [p for p in param_dict.values() if p.grad_req != 'null']

    # Do not apply weight decay on LayerNorm and bias terms
    for _, v in model.collect_params('.*beta|.*gamma|.*bias').items():
        v.wd_mult = 0.0
    # With gradient accumulation, gradients are summed over micro-batches and
    # cleared explicitly at the start of each optimizer step (see below).
    if accumulate > 1:
        for p in params:
            p.grad_req = 'add'

    train_begin_time = time.time()
    begin_time = time.time()
    running_mlm_loss, running_nsp_loss = 0, 0
    # Per-step accumulators: summed MLM loss and masked-token count.
    local_mlm_loss, local_num_masks = 0, mx.nd.array([0], ctx=ctxs[0])
    running_num_tks = 0
    batch_num = 0
    step_num = args.start_step

    logging.debug('Training started')
    logging.info('Generating the first batch of data, which may take a few minutes ...')

    # Wrap the model so forward/backward can be dispatched to every device;
    # with a single device, Parallel runs it inline (0 worker threads).
    parallel_model = DataParallelBERT(model, trainer=fp16_trainer)
    num_ctxes = len(ctxs)
    parallel = nlp.utils.Parallel(num_ctxes if num_ctxes > 1 else 0, parallel_model)

    if backend == 'byteps':
        # Register the mask-count tensor with BytePS, then run one dummy batch
        # so the trainer's parameters can be initialised before training.
        bps.byteps_declare_tensor("local_num_masks")
        bps.byteps_push_pull(local_num_masks, is_average=False,
                             name="local_num_masks", priority=0)
        logging.debug('Broadcast local_num_masks tensor')
        next_batch = next(iter(get_dummy_dataloader(batch_size, args.max_seq_length,
                                                    args.max_predictions_per_seq)))
        data_list = list(split_and_load(next_batch, ctxs))
        parallel.put(data_list[0])
        parallel.get()
        trainer._init_params()

    while step_num < num_train_steps:

        data_train_iter = iter(data_train)
        end_of_batch = False
        next_data_batch = next(data_train_iter)
        while not end_of_batch:
            data_batch = next_data_batch
            if step_num >= num_train_steps:
                break
            # First micro-batch of an optimizer step: advance the step
            # counter, reset accumulated grads and update the LR.
            if batch_num % accumulate == 0:
                step_num += 1
                # if accumulate > 1, grad_req is set to 'add', and zero_grad is required
                if accumulate > 1:
                    param_dict.zero_grad()
                # update learning rate: linear warm-up, then linear decay
                if step_num <= num_warmup_steps:
                    new_lr = lr * step_num / num_warmup_steps
                else:
                    offset = lr * step_num / num_train_steps
                    new_lr = lr - offset
                trainer.set_learning_rate(new_lr)
                if args.profile:
                    profile(step_num, 10, 14, profile_name=args.profile + str(rank))
                if early_stop and step_num == 10:
                    mx.nd.waitall()
                    exit()

            # load data
            data_list = list(split_and_load(data_batch, ctxs))

            ns_label_list, ns_pred_list = [], []
            mask_label_list, mask_pred_list, mask_weight_list = [], [], []

            # Submit one shard per device, then collect one result per shard.
            with mx.autograd.record():
                num_data = len(data_list)
                for i in range(num_data):
                    parallel.put(data_list[i])
            for _ in range(num_data):
                (next_sentence_label, classified, masked_id,
                 decoded, masked_weight, ls1, ls2, valid_length,
                 num_masks) = parallel.get()
                ns_label_list.append(next_sentence_label)
                ns_pred_list.append(classified)
                mask_label_list.append(masked_id)
                mask_pred_list.append(decoded)
                mask_weight_list.append(masked_weight)
                local_num_masks += num_masks
                local_mlm_loss += ls1
                running_num_tks += valid_length.sum()
            # pre fetch next batch; the loader being exhausted ends this pass
            # over data_train (the outer loop restarts it if steps remain).
            try:
                next_data_batch = next(data_train_iter)
            except StopIteration:
                end_of_batch = True

            # update
            if (batch_num + 1) % accumulate == 0:
                # One entry per optimizer step: this step's mean MLM loss per
                # local masked token (computed before the cross-worker sum).
                running_mlm_loss += local_mlm_loss / local_num_masks
                if backend == 'horovod':
                    hvd.allreduce_(local_num_masks, average=False, name='local_num_masks')
                elif backend == 'byteps':
                    bps.byteps_push_pull(local_num_masks, is_average=False,
                                         name="local_num_masks", priority=0)
                # because byteps implicitly set scale /= num_workers
                fp16_trainer.step(local_num_masks * num_workers, max_norm=local_num_masks,
                                  num_ctxs=len(ctxs) * num_workers)
                local_num_masks, local_mlm_loss = 0, 0
            # update metrics
            if args.no_compute_acc:
                for mask_pred_i in mask_pred_list:
                    mask_pred_i.wait_to_read()
            else:
                nsp_metric.update(ns_label_list, ns_pred_list)
                mlm_metric.update(mask_label_list, mask_pred_list, mask_weight_list)

            # logging (only on the last micro-batch of a step)
            if (step_num + 1) % (args.log_interval) == 0 and (batch_num + 1) % accumulate == 0:
                # running_mlm_loss already holds one normalised value per
                # optimizer step, so both branches pass it unscaled; dividing
                # by `accumulate` here would under-report the loss.
                if args.no_compute_acc:
                    log_noacc(begin_time, running_num_tks, running_mlm_loss,
                              0, step_num, trainer, args.log_interval)
                else:
                    # NOTE(review): running_nsp_loss is never accumulated in
                    # this variant, so the reported NSP loss is always 0.
                    log(begin_time, running_num_tks, running_mlm_loss,
                        running_nsp_loss / accumulate, step_num,
                        mlm_metric, nsp_metric, trainer, args.log_interval)
                    mlm_metric.reset_local()
                    nsp_metric.reset_local()
                begin_time = time.time()
                running_mlm_loss = running_nsp_loss = running_num_tks = 0

            # NOTE: checkpoint saving (save_states / save_parameters every
            # args.ckpt_interval steps and at the end) is disabled in this variant.

            # Evaluate once per optimizer step (on its last micro-batch), not
            # once per micro-batch, when gradient accumulation is enabled.
            if (step_num + 1) % args.eval_interval == 0 and data_eval \
                    and (batch_num + 1) % accumulate == 0:
                # eval data is always based on a fixed npz file.
                dataset_eval = get_pretrain_data_npz(data_eval, batch_size_eval,
                                                     1, False, 1, vocab)
                evaluate(dataset_eval, model, ctxs, args.log_interval, args.dtype,
                         rank, num_workers)

            batch_num += 1

    mx.nd.waitall()
    train_end_time = time.time()
    logging.info('Train cost={:.1f}s'.format(train_end_time - train_begin_time))