def training_step(self, batch, batch_idx):
    tokens, types, sentence_order, loss_mask, lm_labels, padding_mask = self.process_batch(batch)
    if not self.cfg.bert_binary_head:
        types = None
    output_tensor = self(tokens, padding_mask, token_type_ids=types, lm_labels=lm_labels)
    loss_dict = self.loss_func(loss_mask, sentence_order, output_tensor)
    if 'sop loss' in loss_dict:
        lm_loss = loss_dict['lm loss']
        sop_loss = loss_dict['sop loss']
        loss = lm_loss + sop_loss
        reduced_loss = average_losses_across_data_parallel_group([loss, lm_loss, sop_loss])
        self._reduced_loss_buffer.append(reduced_loss[0])
        self._reduced_lm_loss_buffer.append(reduced_loss[1])
        self._reduced_sop_loss_buffer.append(reduced_loss[2])
    else:
        lm_loss = loss_dict['lm loss']
        loss = lm_loss
        reduced_loss = average_losses_across_data_parallel_group([loss, lm_loss])
        self._reduced_loss_buffer.append(reduced_loss[0])
        self._reduced_lm_loss_buffer.append(reduced_loss[1])

    if (batch_idx + 1) % self.trainer.accumulate_grad_batches == 0:
        # Reduced loss for logging.
        average_reduced_loss = sum(self._reduced_loss_buffer) / len(self._reduced_loss_buffer)
        self.log('reduced_train_loss', average_reduced_loss, prog_bar=True)
        if len(self._reduced_sop_loss_buffer) > 0:
            average_reduced_lm_loss = sum(self._reduced_lm_loss_buffer) / len(self._reduced_lm_loss_buffer)
            average_reduced_sop_loss = sum(self._reduced_sop_loss_buffer) / len(self._reduced_sop_loss_buffer)
            self.log('reduced_lm_train_loss', average_reduced_lm_loss, prog_bar=True)
            self.log('reduced_sop_train_loss', average_reduced_sop_loss, prog_bar=True)
        lr = self._optimizer.param_groups[0]['lr']
        self.log('lr', lr)
        self.log('global_step', self.trainer.global_step, prog_bar=True)
        self.log(
            'consumed_samples',
            self.compute_consumed_samples(self.trainer.global_step - self.init_global_step),
            prog_bar=True,
        )
        self._reduced_loss_buffer = []
        self._reduced_lm_loss_buffer = []
        self._reduced_sop_loss_buffer = []
    return loss
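# Every *_step in this file funnels its loss through
# average_losses_across_data_parallel_group to turn a rank-local value into one
# averaged over the data-parallel group. A minimal sketch of what such a helper
# does, modeled on the Megatron-LM utility of the same name (assumes an
# initialized apex/megatron parallel_state; a sketch, not the verbatim
# implementation):
import torch
from apex.transformer import parallel_state


def average_losses_across_data_parallel_group(losses):
    # Pack the scalar losses into one tensor so a single all-reduce suffices.
    averaged = torch.cat([loss.clone().detach().view(1) for loss in losses])
    torch.distributed.all_reduce(averaged, group=parallel_state.get_data_parallel_group())
    averaged = averaged / torch.distributed.get_world_size(group=parallel_state.get_data_parallel_group())
    return averaged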
def test_epoch_end(self, outputs):
    averaged_loss = average_losses_across_data_parallel_group(outputs)
    logging.info(f'test_loss: {averaged_loss[0]}')
    self.log(
        'consumed_samples',
        self.compute_consumed_samples(self.trainer.global_step - self.init_global_step),
    )
    return averaged_loss
def inference_epoch_end(self, outputs):
    losses = [x['loss'] for x in outputs]
    averaged_loss = average_losses_across_data_parallel_group(losses)

    all_preds = []
    all_labels = []
    for item in outputs:
        preds = item['predicted_token_ids'].cpu().numpy().tolist()
        labels = item['labels'].cpu().numpy().tolist()
        for pred, label in zip(preds, labels):
            # Truncate the prediction at the first EOS, then strip special tokens
            # from both sides before decoding to text.
            if self.tokenizer.eos_id in pred:
                idx = pred.index(self.tokenizer.eos_id)
                pred = pred[:idx]
            pred = [id for id in pred if id not in self.tokenizer.special_token_to_id.values()]
            label = [id for id in label if id not in self.tokenizer.special_token_to_id.values()]
            pred = self.tokenizer.ids_to_text(pred)
            label = self.tokenizer.ids_to_text(label)
            all_preds.append(pred)
            all_labels.append(label)

    # Exact-match accuracy over the decoded strings.
    correct = 0
    for pred, label in zip(all_preds, all_labels):
        if pred == label:
            correct += 1
    acc = correct / len(all_preds)
    return averaged_loss[0], acc
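# The EOS handling in inference_epoch_end above truncates each prediction at
# the first end-of-sequence id before special tokens are stripped and exact
# match is scored. A tiny self-contained illustration with hypothetical ids:
eos_id = 2
pred = [15, 7, 2, 9, 9]
if eos_id in pred:
    pred = pred[:pred.index(eos_id)]
assert pred == [15, 7]  # everything from the first EOS onward is discarded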
def training_step(self, batch, batch_idx):
    input_tokens_id = batch['tokens']
    input_attn_mask = batch['tokens_mask']
    loss_mask = batch['loss_mask']
    retrieved_ids = batch['retrieved_ids']
    retrieved_attn_mask = batch['retrieved_emb_mask']
    labels = batch['labels']
    loss = self(input_tokens_id, input_attn_mask, retrieved_ids, retrieved_attn_mask, labels=labels)
    loss_mask = loss_mask.float()
    lm_loss = torch.sum(loss.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
    reduced_loss = average_losses_across_data_parallel_group([lm_loss])
    self._reduced_loss_buffer.append(reduced_loss[0])

    if (batch_idx + 1) % self.trainer.accumulate_grad_batches == 0:
        # Reduced loss for logging.
        average_reduced_loss = sum(self._reduced_loss_buffer) / len(self._reduced_loss_buffer)
        self.log('reduced_train_loss', average_reduced_loss, prog_bar=True)
        lr = self._optimizer.param_groups[0]['lr']
        self.log('lr', lr)
        self.log('global_step', self.trainer.global_step, prog_bar=True)
        self.log(
            'consumed_samples',
            self.compute_consumed_samples(self.trainer.global_step - self.init_global_step),
            prog_bar=True,
        )
        self._reduced_loss_buffer = []
    return lm_loss
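# The masked reduction in training_step above averages the per-token loss over
# only the positions where loss_mask is 1. A tiny self-contained example with
# hypothetical values:
import torch

per_token_loss = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])  # [batch, seq]
loss_mask = torch.tensor([[1.0, 1.0, 0.0], [1.0, 0.0, 0.0]])       # 0.0 = ignore
lm_loss = torch.sum(per_token_loss.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
assert abs(lm_loss.item() - (1.0 + 2.0 + 4.0) / 3) < 1e-6  # mean over unmasked tokens only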
def training_step(self, batch, batch_idx):
    tokens_enc, tokens_dec, loss_mask, labels, enc_mask, dec_mask = self.process_batch(batch)
    tokens_loss = itemgetter("tokens_loss")(
        self.model(
            tokens_enc,
            tokens_dec,
            enc_mask,
            dec_mask,
            tokentype_ids=None,
            lm_labels=labels,
        )
    )
    loss = self.model.loss_func(loss_mask, tokens_loss)
    self.log('train_loss', loss)
    # Reduced loss for logging.
    reduced_loss = average_losses_across_data_parallel_group([loss])
    # Cache reduced loss while accumulating gradients.
    self.model._reduced_loss_buffer.append(reduced_loss[0])
    if (batch_idx + 1) % self.trainer.accumulate_grad_batches == 0:
        average_reduced_loss = sum(self.model._reduced_loss_buffer) / len(self.model._reduced_loss_buffer)
        self.log('reduced_train_loss', average_reduced_loss, prog_bar=True)
        lr = self._optimizer.param_groups[0]['lr']
        self.log('lr', lr)
        self.log('global_step', self.trainer.global_step, prog_bar=True)
        self.model._reduced_loss_buffer = []
    return loss
def training_step(self, batch, batch_idx):
    input_ids, labels, loss_mask, position_ids, attention_mask, taskname_ids = batch
    output = self.forward(input_ids, position_ids, attention_mask, taskname_ids, labels, inference=False)
    output_tensor, encoder_hidden_states = output
    loss = self.frozen_model.loss_func(loss_mask, output_tensor)
    self.log('train_loss', loss)
    # Reduced loss for logging.
    reduced_loss = average_losses_across_data_parallel_group([loss])
    # Cache reduced loss while accumulating gradients.
    self._reduced_loss_buffer.append(reduced_loss[0])
    if (batch_idx + 1) % self.trainer.accumulate_grad_batches == 0:
        average_reduced_loss = sum(self._reduced_loss_buffer) / len(self._reduced_loss_buffer)
        self.log('reduced_train_loss', average_reduced_loss, prog_bar=True)
        lr = self._optimizer.param_groups[0]['lr']
        self.log('lr', lr)
        self.log('global_step', self.trainer.global_step, prog_bar=True)
        self._reduced_loss_buffer = []
    return loss
def training_step(self, batch, batch_idx):
    tokens_enc, tokens_dec, loss_mask, labels, enc_mask, dec_mask = self.process_batch(batch)
    tokens_loss = itemgetter("tokens_loss")(
        self(tokens_enc, tokens_dec, enc_mask, dec_mask, tokentype_ids=None, lm_labels=labels)
    )
    loss = self.loss_func(loss_mask, tokens_loss)
    self.log('train_loss', loss)
    # Reduced loss for logging. This averages the loss across all workers,
    # unlike "loss" above, which is specific to a single DDP rank.
    reduced_loss = average_losses_across_data_parallel_group([loss])
    # Cache reduced loss while accumulating gradients.
    self._reduced_loss_buffer.append(reduced_loss[0])
    if (batch_idx + 1) % self.trainer.accumulate_grad_batches == 0:
        average_reduced_loss = sum(self._reduced_loss_buffer) / len(self._reduced_loss_buffer)
        self.log('reduced_train_loss', average_reduced_loss, prog_bar=True)
        lr = self._optimizer.param_groups[0]['lr']
        self.log('lr', lr)
        self.log('global_step', self.trainer.global_step, prog_bar=True)
        self.log(
            'consumed_samples',
            self.compute_consumed_samples(self.trainer.global_step - self.init_global_step),
            prog_bar=True,
        )
        self._reduced_loss_buffer = []
    return loss
def inference_epoch_end(self, outputs):
    losses = [x['loss'] for x in outputs]
    averaged_loss = average_losses_across_data_parallel_group(losses)
    val_acc = self.acc_metric.compute()
    self.log('validation_loss', averaged_loss)
    self.log('validation_acc', val_acc['acc'])
    self.acc_metric.reset()
    return averaged_loss[0], val_acc
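# `acc_metric` above follows the torchmetrics update/compute/reset protocol,
# and its compute() evidently returns a dict with an 'acc' key. A minimal
# sketch of such a metric (a hypothetical class, not the actual implementation
# used here):
import torch
from torchmetrics import Metric


class ExactMatchAccuracy(Metric):
    def __init__(self):
        super().__init__()
        self.add_state("correct", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, preds, targets):
        # preds/targets are lists of decoded strings; exact match counts as correct.
        for p, t in zip(preds, targets):
            self.correct += float(p == t)
            self.total += 1.0

    def compute(self):
        return {'acc': self.correct / self.total}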
def validation_step(self, batch, batch_idx):
    tokens_enc, tokens_dec, loss_mask, labels, enc_mask, dec_mask = self.process_batch(batch)
    tokens_loss = itemgetter("tokens_loss")(
        self(tokens_enc, tokens_dec, enc_mask, dec_mask, tokentype_ids=None, lm_labels=labels)
    )
    loss = self.loss_func(loss_mask, tokens_loss)
    reduced_loss = average_losses_across_data_parallel_group([loss])
    return reduced_loss
def inference_epoch_end(self, outputs):
    losses = [x['loss'] for x in outputs]
    averaged_loss = average_losses_across_data_parallel_group(losses)
    acc_result = self.acc_metrics.compute()
    self.log('validation_loss', averaged_loss)
    self.log('validation_acc', acc_result['acc'])
    for lang in self.cfg.eval_languages:
        self.log(f'{lang}_acc', acc_result[lang])
    self.acc_metrics.reset()
    return averaged_loss[0], acc_result
def validation_step(self, batch, batch_idx):
    input_tokens_id = batch['tokens']
    input_attn_mask = batch['tokens_mask']
    loss_mask = batch['loss_mask']
    retrieved_ids = batch['retrieved_ids']
    retrieved_attn_mask = batch['retrieved_emb_mask']
    labels = batch['labels']
    loss = self(input_tokens_id, input_attn_mask, retrieved_ids, retrieved_attn_mask, labels=labels)
    loss_mask = loss_mask.float()
    lm_loss = torch.sum(loss.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
    reduced_loss = average_losses_across_data_parallel_group([lm_loss])
    return reduced_loss
def validation_step(self, batch, batch_idx):
    tokens_enc, tokens_dec, loss_mask, labels, enc_mask, dec_mask, enc_dec_mask = self.process_batch(batch)
    output_tensor, encoder_hidden_states = self(
        tokens_enc, tokens_dec, enc_mask, dec_mask, enc_dec_mask, tokentype_ids=None, lm_labels=labels
    )
    loss = self.loss_func(loss_mask, output_tensor)
    reduced_loss = average_losses_across_data_parallel_group([loss])
    return reduced_loss
def training_step(self, batch, batch_idx):
    loss, _, _, _, _ = self.get_loss(batch)
    self.log('train_loss', loss)
    # Reduced loss for logging.
    reduced_loss = average_losses_across_data_parallel_group([loss])
    # Cache reduced loss while accumulating gradients.
    self._reduced_loss_buffer.append(reduced_loss[0])
    if (batch_idx + 1) % self.trainer.accumulate_grad_batches == 0:
        average_reduced_loss = sum(self._reduced_loss_buffer) / len(self._reduced_loss_buffer)
        self.log('reduced_train_loss', average_reduced_loss, prog_bar=True)
        lr = self._optimizer.param_groups[0]['lr']
        self.log('lr', lr)
        self.log('global_step', self.trainer.global_step, prog_bar=True)
        self._reduced_loss_buffer = []
    return loss
def validation_step(self, batch, batch_idx):
    tokens, types, sentence_order, loss_mask, lm_labels, padding_mask = self.process_batch(batch)
    if not self.cfg.bert_binary_head:
        types = None
    output_tensor = self(tokens, padding_mask, tokentype_ids=types, lm_labels=lm_labels)
    loss_dict = self.loss_func(loss_mask, sentence_order, output_tensor)
    if 'sop loss' in loss_dict:
        lm_loss = loss_dict['lm loss']
        sop_loss = loss_dict['sop loss']
        loss = lm_loss + sop_loss
    else:
        lm_loss = loss_dict['lm loss']
        loss = lm_loss
    reduced_loss = average_losses_across_data_parallel_group([loss])
    return reduced_loss
def validation_step(self, batch, batch_idx):
    tokens, labels, loss_mask, attention_mask, position_ids = self.process_batch(batch)
    output_tensor = self(tokens, position_ids, attention_mask, labels)
    loss = self.loss_func(loss_mask, output_tensor)
    reduced_loss = average_losses_across_data_parallel_group([loss])
    # TODO: add text generation
    # take the first k tokens and then generate text - compare with ground truth (won't look similar in general)
    """
    k = num_context
    n = max_generate_length
    context_tokens = tokens[0:k]
    while k < n:
        output_tensor = self(context_tokens)
        next_token = sample(output_tensor)
        context_tokens.append(next_token)
        k += 1
    """
    return reduced_loss
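# A runnable sketch of the greedy generation loop outlined in the TODO above.
# `model` is a hypothetical stand-in that maps token ids [1, seq] to logits
# [1, seq, vocab]; the real forward above takes more arguments:
import torch


@torch.no_grad()
def greedy_generate(model, context_tokens, max_generate_length):
    tokens = context_tokens.clone()  # [1, k] prompt of k context tokens
    while tokens.size(1) < max_generate_length:
        logits = model(tokens)                        # [1, seq, vocab]
        next_token = logits[:, -1, :].argmax(dim=-1)  # greedy "sample"
        tokens = torch.cat([tokens, next_token.unsqueeze(-1)], dim=-1)
    return tokens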
def validation_step(self, batch, batch_idx):
    if self.use_soft_prompts:
        tokens, labels, prompt_tags, attention_mask, loss_mask, text_position_ids = batch
        tokens = tokens.to(self.device)
        labels = labels.to(self.device)
        attention_mask = attention_mask.to(self.device)
        loss_mask = loss_mask.to(self.device)
        text_position_ids = text_position_ids.to(self.device)
        output_tensor = self(tokens, text_position_ids, attention_mask, labels, prompt_tags)
    else:
        tokens, labels, loss_mask, attention_mask, position_ids = self.process_batch(batch)
        output_tensor = self(tokens, position_ids, attention_mask, labels)
    loss = self.loss_func(loss_mask, output_tensor)
    reduced_loss = average_losses_across_data_parallel_group([loss])
    return reduced_loss
def training_step(self, batch, batch_idx):
    tokens, labels, loss_mask, attention_mask, position_ids = self.process_batch(batch)
    output_tensor = self(tokens, position_ids, attention_mask, labels)
    loss = self.loss_func(loss_mask, output_tensor)
    reduced_loss = average_losses_across_data_parallel_group([loss])
    # Cache reduced loss while accumulating gradients.
    self._reduced_loss_buffer.append(reduced_loss[0])
    if (batch_idx + 1) % self.trainer.accumulate_grad_batches == 0:
        # Reduced loss for logging.
        average_reduced_loss = sum(self._reduced_loss_buffer) / len(self._reduced_loss_buffer)
        self.log('reduced_train_loss', average_reduced_loss, prog_bar=True)
        lr = self._optimizer.param_groups[0]['lr']
        self.log('lr', lr)
        self.log('global_step', self.trainer.global_step, prog_bar=True)
        self.log('consumed_samples', self.compute_consumed_samples(self.trainer.global_step), prog_bar=True)
        self._reduced_loss_buffer = []
    return loss
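# Unlike the other steps in this file, the training_step above passes
# `self.trainer.global_step` to compute_consumed_samples without subtracting
# `self.init_global_step`, so its sample count does not adjust for resumed
# runs. A plausible sketch of the helper itself, assuming an
# `init_consumed_samples` attribute and a fixed global batch size (an
# assumption, not the verbatim implementation):
def compute_consumed_samples(self, steps_since_resume=0):
    return int(self.init_consumed_samples + steps_since_resume * self.cfg.global_batch_size)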
def test_epoch_end(self, outputs):
    averaged_loss = average_losses_across_data_parallel_group(outputs)
    logging.info(f'test_loss: {averaged_loss[0]}')
def loss_func(output_tensor):
    # Nested helper: `self` and `loss_mask` are captured from the enclosing
    # scope (see the sketch of the enclosing pattern below).
    loss = self.loss_func(loss_mask, output_tensor)
    reduced_loss = average_losses_across_data_parallel_group([loss])
    return loss, {'avg': reduced_loss}
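# Megatron-style pipelined forward-backward schedules expect exactly this
# shape of callable, mapping the model output to (loss, logging_dict). A
# sketch of the enclosing pattern, with hypothetical batch unpacking borrowed
# from the GPT steps in this file (not the verbatim implementation):
def get_forward_output_and_loss_func(self):
    def fwd_output_and_loss_func(batch, model):
        tokens, labels, loss_mask, attention_mask, position_ids = self.process_batch(batch)
        output_tensor = model(tokens, position_ids, attention_mask, labels)

        def loss_func(output_tensor):
            loss = self.loss_func(loss_mask, output_tensor)
            reduced_loss = average_losses_across_data_parallel_group([loss])
            return loss, {'avg': reduced_loss}

        return output_tensor, loss_func

    return fwd_output_and_loss_func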
def eval_epoch_end(self, outputs, mode):
    if isinstance(outputs[0], dict):
        outputs = [outputs]

    loss_list = []
    bleu_score_list = []
    for dataloader_idx, output in enumerate(outputs):
        averaged_loss = average_losses_across_data_parallel_group([x['loss'] for x in output])
        inputs = list(itertools.chain(*[x['inputs'] for x in output]))
        translations = list(itertools.chain(*[x['translations'] for x in output]))
        ground_truths = list(itertools.chain(*[x['ground_truths'] for x in output]))
        assert len(translations) == len(inputs)
        assert len(translations) == len(ground_truths)

        # Gather translations and ground truths from all workers.
        tr_gt_inp = [None for _ in range(parallel_state.get_data_parallel_world_size())]
        # We also need to drop pairs where the ground truth is an empty string.
        torch.distributed.all_gather_object(
            tr_gt_inp,
            [(t, g, i) for (t, g, i) in zip(translations, ground_truths, inputs)],
            group=parallel_state.get_data_parallel_group(),
        )
        if parallel_state.get_data_parallel_rank() == 0:
            _translations = []
            _ground_truths = []
            _inputs = []
            # Deduplicate sentences that may have been distributed across multiple data parallel ranks.
            gt_inp_set = set()
            for rank in range(0, parallel_state.get_data_parallel_world_size()):
                for t, g, i in tr_gt_inp[rank]:
                    if g + i not in gt_inp_set:
                        gt_inp_set.add(g + i)
                        _translations.append(t)
                        _ground_truths.append(g)
                        _inputs.append(i)

            if self.tgt_language in ['ja']:
                sacre_bleu = corpus_bleu(_translations, [_ground_truths], tokenize="ja-mecab")
            elif self.tgt_language in ['zh']:
                sacre_bleu = corpus_bleu(_translations, [_ground_truths], tokenize="zh")
            else:
                sacre_bleu = corpus_bleu(_translations, [_ground_truths], tokenize="13a")

            # Scale by the data-parallel world size so that averaging across
            # ranks (the other ranks log 0.0 below) recovers the true score
            # when self.log(..., sync_dist=True) reduces it.
            bleu_score = sacre_bleu.score * parallel_state.get_data_parallel_world_size()

            dataset_name = "Validation" if mode == 'val' else "Test"
            logging.info(f"{dataset_name}, Dataloader index: {dataloader_idx}, Set size: {len(_translations)}")
            logging.info(
                f"{dataset_name}, Dataloader index: {dataloader_idx}, "
                f"SacreBLEU = {bleu_score / parallel_state.get_data_parallel_world_size()}"
            )
            logging.info(f"{dataset_name}, Dataloader index: {dataloader_idx}, Translation Examples:")
            logging.info('============================================================')
            for example_idx in range(0, 3):
                random_index = random.randint(0, len(_translations) - 1)
                logging.info("    " + '\u0332'.join(f"Example {example_idx}:"))
                logging.info(f"    Input:        {_inputs[random_index]}")
                logging.info(f"    Prediction:   {_translations[random_index]}")
                logging.info(f"    Ground Truth: {_ground_truths[random_index]}")
                logging.info('============================================================')
        else:
            bleu_score = 0.0

        loss_list.append(averaged_loss[0].cpu().numpy())
        bleu_score_list.append(bleu_score)
        if dataloader_idx == 0:
            self.log(f'{mode}_sacreBLEU', bleu_score, sync_dist=True)
            self.log(f'{mode}_loss', averaged_loss[0], prog_bar=True)
            if self.multilingual:
                self._log_multilingual_bleu_and_loss(dataloader_idx, bleu_score, averaged_loss[0], mode)
        else:
            if self.multilingual:
                self._log_multilingual_bleu_and_loss(dataloader_idx, bleu_score, averaged_loss[0], mode)
            else:
                self.log(f'{mode}_sacreBLEU_dl_index_{dataloader_idx}', bleu_score, sync_dist=True)
                self.log(f'{mode}_loss_dl_index_{dataloader_idx}', averaged_loss[0], prog_bar=False)

    if len(loss_list) > 1:
        self.log(f"{mode}_loss_avg", np.mean(loss_list), sync_dist=True)
        self.log(f"{mode}_sacreBLEU_avg", np.mean(bleu_score_list), sync_dist=True)
def validation_epoch_end(self, outputs):
    averaged_loss = average_losses_across_data_parallel_group(outputs)
    self.log('val_loss', averaged_loss[0], prog_bar=True)
    self.log(
        'consumed_samples',
        self.compute_consumed_samples(self.trainer.global_step - self.init_global_step),
    )
def validation_epoch_end(self, outputs):
    averaged_loss = average_losses_across_data_parallel_group(outputs)
    # Logging happens only on rank zero, but with pipeline parallelism the loss
    # lives on the last stage, so broadcast it from the last rank first.
    torch.distributed.broadcast(averaged_loss, get_last_rank())
    self.log('val_loss', averaged_loss, prog_bar=True, rank_zero_only=True)
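# A plausible sketch of get_last_rank, assuming it simply returns the highest
# global rank, i.e. the last pipeline stage where the loss is computed (an
# assumption, not the verbatim implementation):
import torch


def get_last_rank():
    return torch.distributed.get_world_size() - 1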