def validate(model, valid_src, valid_tgt, toker, vocab, device, local_rank):
    model.eval()
    val_loss = 0
    n_correct = 0
    n_word = 0
    loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-1, reduction='sum')
    with open(valid_src, 'r') as src_reader, \
            open(valid_tgt, 'r') as tgt_reader:
        for i, (src, tgt) in enumerate(zip(src_reader, tgt_reader)):
            if local_rank != -1:
                # shard the validation data across ranks so every process
                # scores a disjoint subset (results are gathered below)
                global_rank = torch.distributed.get_rank()
                world_size = torch.distributed.get_world_size()
                if i % world_size != global_rank:
                    continue
            input_ids, type_ids, mask, labels = convert_raw_input_to_features(
                src, tgt, toker, vocab, device)
            prediction_scores = model(input_ids, type_ids, mask)
            loss = loss_fct(prediction_scores.squeeze(0), labels.view(-1))
            val_loss += loss.item()
            n_correct += accuracy_count(prediction_scores, labels)
            n_word += (labels != -1).long().sum().item()
    if local_rank != -1:
        val_loss = sum(all_gather_list(val_loss))
        n_correct = sum(all_gather_list(n_correct))
        n_word = sum(all_gather_list(n_word))
    val_loss /= n_word
    acc = n_correct / n_word
    val_log = {'val_loss': val_loss, 'val_acc': acc}
    model.train()
    return val_log
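# `accuracy_count` is called in `validate` above but not defined in this
# file. A minimal sketch of the assumed behaviour: count masked positions
# (label != -1) whose argmax prediction matches the gold token id.
def accuracy_count(prediction_scores, labels):
    # prediction_scores: (1, seq_len, vocab_size); labels: (1, seq_len)
    preds = prediction_scores.view(-1, prediction_scores.size(-1)).argmax(dim=-1)
    labels = labels.view(-1)
    valid = labels != -1  # same ignore_index as the CrossEntropyLoss above
    return (preds[valid] == labels[valid]).long().sum().item()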
def all_gather_stats_list(stat_list, max_size=4096):
    """
    Gather a `Statistics` list across all processes/nodes

    Args:
        stat_list(list([`Statistics`])): list of statistics objects to
            gather across all processes/nodes
        max_size(int): max buffer size to use

    Returns:
        our_stats(list([`Statistics`])): list of updated stats
    """
    # apply https://github.com/OpenNMT/OpenNMT-py/commit/2a2621d770adb593942d7999a59401aff35d646a
    from torch.distributed import get_rank
    from onmt.utils.distributed import all_gather_list

    # Get a list of world_size lists with len(stat_list) Statistics objects
    all_stats = all_gather_list(stat_list, max_size=max_size)

    our_rank = get_rank()
    our_stats = all_stats[our_rank]
    for other_rank, stats in enumerate(all_stats):
        if other_rank == our_rank:
            continue
        for i, stat in enumerate(stats):
            our_stats[i].update(stat, update_n_src_words=True)
    return our_stats
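# `all_gather_list` (from onmt.utils.distributed) is used throughout this
# code. Roughly, it pickles an arbitrary Python object into a fixed-size byte
# buffer, all-gathers the buffers, and unpickles one object per rank. The
# following is a simplified sketch of that idea, not the exact library code.
import pickle
import torch
import torch.distributed as dist

def all_gather_list_sketch(data, max_size=4096):
    world_size = dist.get_world_size()
    enc = pickle.dumps(data)
    assert len(enc) + 2 <= max_size, 'encoded data exceeds max_size'
    # first two bytes store the payload length
    buf = torch.zeros(max_size, dtype=torch.uint8, device='cuda')
    buf[0] = len(enc) // 255
    buf[1] = len(enc) % 255
    buf[2:len(enc) + 2] = torch.tensor(list(enc), dtype=torch.uint8)
    gathered = [torch.zeros_like(buf) for _ in range(world_size)]
    dist.all_gather(gathered, buf)
    results = []
    for t in gathered:
        size = t[0].item() * 255 + t[1].item()
        results.append(pickle.loads(bytes(t[2:size + 2].tolist())))
    return results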
def _norm(self, batch):
    if self._norm_method == "tokens":
        # normalize by the number of non-padding target tokens
        norm = batch.tgt[1:].ne(self._train_loss.padding_idx).sum()
    else:
        # normalize by the number of sentences
        norm = batch.batch_size
    if self.n_gpu > 1:
        # sum the normalization term over all GPUs so every process
        # divides by the same global count
        norm = sum(all_gather_list(norm))
    return norm
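# Assumed usage of `_norm` (sketch, not the actual trainer code): divide the
# summed loss by the gathered normalization term before backward so that the
# gradient scale matches single-GPU training once per-GPU gradients are
# all-reduced (summed). The method name below is hypothetical.
def _backward_normalized(self, loss, batch):
    norm = self._norm(batch)
    (loss / float(norm)).backward()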
def main(opts):
    if opts.local_rank == -1:
        assert torch.cuda.is_available()
        device = torch.device("cuda")
        n_gpu = 1
    else:
        torch.cuda.set_device(opts.local_rank)
        device = torch.device("cuda", opts.local_rank)
        # Initializes the distributed backend which will take care of
        # synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
        n_gpu = torch.distributed.get_world_size()
    logger.info("device: {} n_gpu: {}, distributed training: {}, "
                "16-bits training: {}".format(
                    device, n_gpu, bool(opts.local_rank != -1), opts.fp16))
    opts.n_gpu = n_gpu

    if opts.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, "
                         "should be >= 1".format(
                             opts.gradient_accumulation_steps))

    is_master = opts.local_rank == -1 or torch.distributed.get_rank() == 0
    if is_master:
        save_training_meta(opts)

    random.seed(opts.seed)
    np.random.seed(opts.seed)
    torch.manual_seed(opts.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(opts.seed)

    tokenizer = BertTokenizer.from_pretrained(
        opts.bert_model, do_lower_case='uncased' in opts.bert_model)

    print("Loading Train Dataset", opts.train_file)
    vocab_dump = torch.load(opts.vocab_file)
    vocab = vocab_dump['tgt'].fields[0][1].vocab.stoi
    train_dataset = BertDataset(opts.train_file, tokenizer, vocab,
                                seq_len=opts.max_seq_length,
                                max_len=opts.max_sent_length)

    # Prepare model
    model = BertForSeq2seq.from_pretrained(opts.bert_model)
    embedding = convert_embedding(
        tokenizer, vocab, model.bert.embeddings.word_embeddings.weight)
    model.update_output_layer(embedding)
    if opts.fp16:
        model.half()
    model.to(device)
    if opts.local_rank != -1:
        # need to make sure models are the same in the beginning
        params = [p.data for p in model.parameters()]
        broadcast_tensors(params)
    for name, module in model.named_modules():
        # we might want to tune dropout for smaller dataset
        if isinstance(module, torch.nn.Dropout):
            module.p = opts.dropout

    # Prepare optimizer
    param_optimizer = [(n, p) for n, p in model.named_parameters()
                       if 'pooler' not in n]
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer
                    if not any(nd in n for nd in no_decay)],
         'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer
                    if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0}
    ]

    if opts.fp16:
        try:
            from apex.optimizers import FP16_Optimizer
            from apex.optimizers import FusedAdam
        except ImportError:
            raise ImportError("Please install apex from "
                              "https://www.github.com/nvidia/apex "
                              "to use distributed and fp16 training.")

        optimizer = FusedAdam(optimizer_grouped_parameters,
                              lr=opts.learning_rate,
                              bias_correction=False,
                              max_grad_norm=1.0)
        if opts.loss_scale == 0:
            optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
        else:
            optimizer = FP16_Optimizer(optimizer,
                                       static_loss_scale=opts.loss_scale)
    else:
        optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=opts.learning_rate,
                             warmup=opts.warmup_proportion,
                             t_total=opts.num_train_steps)

    global_step = 0
    logger.info("***** Running training *****")
    logger.info(" Num examples = %d", len(train_dataset))
    logger.info(" Batch size = %d", opts.train_batch_size)
    logger.info(" Accumulate steps = %d", opts.gradient_accumulation_steps)
    logger.info(" Num steps = %d", opts.num_train_steps)

    if opts.local_rank == -1:
        train_sampler = TokenBucketSampler(
            train_dataset.lens, bucket_size=8192,
            batch_size=opts.train_batch_size, droplast=True)
        train_dataloader = DataLoader(train_dataset,
                                      batch_sampler=train_sampler,
                                      num_workers=4,
                                      collate_fn=BertDataset.pad_collate)
    else:
        train_sampler = DistributedTokenBucketSampler(
            n_gpu, opts.local_rank,
            train_dataset.lens, bucket_size=8192,
            batch_size=opts.train_batch_size, droplast=True)
        train_dataloader = DataLoader(train_dataset,
                                      batch_sampler=train_sampler,
                                      num_workers=4,
                                      collate_fn=BertDataset.pad_collate)

    if is_master:
        TB_LOGGER.create(join(opts.output_dir, 'log'))

    running_loss = RunningMeter('loss')
    model.train()
    if is_master:
        pbar = tqdm(total=opts.num_train_steps)
    else:
        logger.disabled = True
        pbar = None
    n_examples = 0
    n_epoch = 0
    start = time()
    while True:
        for step, batch in enumerate(train_dataloader):
            batch = tuple(t.to(device) if t is not None else t
                          for t in batch)
            input_ids, input_mask, segment_ids, lm_label_ids = batch
            n_examples += input_ids.size(0)
            mask = lm_label_ids != -1
            loss = model(input_ids, segment_ids, input_mask,
                         lm_label_ids, mask, True)
            if opts.fp16:
                optimizer.backward(loss)
            else:
                loss.backward()
            running_loss(loss.item())
            if (step + 1) % opts.gradient_accumulation_steps == 0:
                global_step += 1
                if opts.fp16:
                    # modify learning rate with special warm up BERT uses
                    # if opts.fp16 is False, BertAdam is used that handles
                    # this automatically
                    lr_this_step = opts.learning_rate * warmup_linear(
                        global_step/opts.num_train_steps,
                        opts.warmup_proportion)
                    if lr_this_step < 0:
                        # safeguard for possible miscalculation of train steps
                        lr_this_step = 1e-8
                    for param_group in optimizer.param_groups:
                        param_group['lr'] = lr_this_step
                    TB_LOGGER.add_scalar('lr', lr_this_step, global_step)
                # NOTE: running loss not gathered across GPUs for speed
                TB_LOGGER.add_scalar('loss', running_loss.val, global_step)
                TB_LOGGER.step()
                if opts.local_rank != -1:
                    # gather gradients from every process
                    grads = [p.grad.data for p in model.parameters()
                             if p.requires_grad and p.grad is not None]
                    all_reduce_and_rescale_tensors(grads, float(1))
                optimizer.step()
                optimizer.zero_grad()
                if pbar is not None:
                    pbar.update(1)
                if global_step % 5 == 0:
                    torch.cuda.empty_cache()
                if global_step % 100 == 0:
                    if opts.local_rank != -1:
                        total = sum(all_gather_list(n_examples))
                    else:
                        total = n_examples
                    if is_master:
                        ex_per_sec = int(total / (time()-start))
                        logger.info(f'{total} examples trained at '
                                    f'{ex_per_sec} ex/s')
                        TB_LOGGER.add_scalar('ex_per_s',
                                             ex_per_sec, global_step)
                if global_step % opts.valid_steps == 0:
                    logger.info(f"start validation at Step {global_step}")
                    with torch.no_grad():
                        val_log = validate(model,
                                           opts.valid_src, opts.valid_tgt,
                                           tokenizer, vocab, device,
                                           opts.local_rank)
                    logger.info(f"Val Acc: {val_log['val_acc']}; "
                                f"Val Loss: {val_log['val_loss']}")
                    TB_LOGGER.log_scaler_dict(val_log)
                    if is_master:
                        output_model_file = join(
                            opts.output_dir, 'ckpt',
                            f"model_step_{global_step}.pt")
                        # save cpu checkpoint
                        state_dict = {k: v.cpu()
                                         if isinstance(v, torch.Tensor)
                                         else v
                                      for k, v in model.state_dict().items()}
                        torch.save(state_dict, output_model_file)
            if global_step >= opts.num_train_steps:
                break
        if global_step >= opts.num_train_steps:
            break
        n_epoch += 1
        if is_master:
            logger.info(f"finished {n_epoch} epochs")

    if opts.num_train_steps % opts.valid_steps != 0:
        with torch.no_grad():
            val_log = validate(model, opts.valid_src, opts.valid_tgt,
                               tokenizer, vocab, device, opts.local_rank)
        TB_LOGGER.log_scaler_dict(val_log)
        if is_master:
            output_model_file = join(opts.output_dir, 'ckpt',
                                     f"model_step_{global_step}.pt")
            # save cpu checkpoint
            state_dict = {k: v.cpu() if isinstance(v, torch.Tensor) else v
                          for k, v in model.state_dict().items()}
            torch.save(state_dict, output_model_file)
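# `warmup_linear` above follows the schedule used by pytorch-pretrained-bert:
# linear warmup to the peak LR over the first `warmup` fraction of training,
# then linear decay toward zero. A sketch of the assumed definition; note it
# goes negative once x > 1.0, which is why the training loop clamps
# lr_this_step to 1e-8.
def warmup_linear(x, warmup=0.002):
    if x < warmup:
        return x / warmup
    return 1.0 - x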
def train(self, train_iter_fct, train_steps, valid_iter_fct=None,
          valid_steps=-1):
    """
    The main training loop: iterates over the training data
    (i.e. `train_iter_fct`) and runs validation
    (i.e. iterates over `valid_iter_fct`).

    Args:
        train_iter_fct(function): a function that returns the train
            iterator, e.g. something like
            train_iter_fct = lambda: generator(*args, **kwargs)
        valid_iter_fct(function): same as train_iter_fct, for valid data
        train_steps(int):
        valid_steps(int):
        save_checkpoint_steps(int):

    Return:
        total_stats(`Statistics`): accumulated training statistics
    """
    step = self.optim.training_step + 1
    true_batchs = []
    accum = 0
    normalization = 0
    train_iter = train_iter_fct()

    total_stats, report_stats = Statistics(), Statistics()
    self._start_report_manager(start_time=total_stats.start_time)

    while step <= train_steps:
        reduce_counter = 0
        for i, batch in enumerate(train_iter):
            # Pick batch for this GPU if applicable
            if self._should_skip_batch(i):
                continue

            _, batch_size = batch
            true_batchs.append(batch)
            normalization += batch_size
            accum += 1
            if accum == self.grad_accum_count:
                reduce_counter += 1
                if self.n_gpu > 1:
                    normalization = sum(
                        distributed.all_gather_list(normalization))

                self._gradient_accumulation(true_batchs, normalization,
                                            total_stats, report_stats, step)

                report_stats = self._maybe_report_training(
                    step, train_steps, self.optim.learning_rate(),
                    report_stats)

                true_batchs = []
                accum = 0
                normalization = 0

                # Save checkpoint if applicable on one GPU
                if (step % self.save_checkpoint_steps == 0
                        and self.gpu_rank == 0):
                    self._save(step)

                step += 1
                # Break loop if hit training step limit
                if step > train_steps:
                    break

        # Reset dataset iterator if necessary
        train_iter = train_iter_fct()

    return total_stats
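# `_should_skip_batch` is referenced above but not shown. It is assumed to
# shard batches round-robin across GPUs so each rank trains on a disjoint
# slice of the iterator; a minimal sketch under that assumption:
def _should_skip_batch(self, i):
    return self.n_gpu > 1 and i % self.n_gpu != self.gpu_rank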