def loss_func(loss_mask, sentence_order, output_tensor):
    """Masked LM loss, plus sentence-order-prediction loss when present.

    Returns (total_loss, stats) where stats holds the losses averaged
    across the data-parallel group for logging.
    """
    lm_losses, sop_logits = output_tensor
    lm_losses = lm_losses.float()
    mask = loss_mask.float()
    # Masked mean: average the per-token loss over non-masked positions.
    lm_loss = torch.sum(lm_losses.view(-1) * mask.reshape(-1)) / mask.sum()

    if sop_logits is None:
        averaged = average_losses_across_data_parallel_group([lm_loss])
        return lm_loss, {'lm loss': averaged[0]}

    # Sentence-order prediction is a binary classification over pairs;
    # ignore_index=-1 skips padded entries.
    sop_loss = F.cross_entropy(sop_logits.view(-1, 2).float(),
                               sentence_order.view(-1),
                               ignore_index=-1).float()
    total = lm_loss + sop_loss
    averaged = average_losses_across_data_parallel_group([lm_loss, sop_loss])
    return total, {
        'lm loss': averaged[0],
        'sop loss': averaged[1]
    }
def _cross_entropy_forward_step(batch, model, input_tensor):
    """Simple forward step with cross-entropy loss.

    `batch` may be an iterator (normal case) or an already-materialized
    batch; `input_tensor` is the activation from the previous pipeline
    stage (None on the first stage). Returns (loss, stats) on the last
    pipeline stage, otherwise the output activation.
    """
    timers = get_timers()

    # Get the batch.
    timers('batch-generator').start()
    try:
        batch_ = next(batch)
    except Exception:
        # Fix: was `except BaseException`, which also swallowed
        # KeyboardInterrupt/SystemExit. `batch` may already be a batch
        # rather than an iterator; fall back to using it directly.
        batch_ = batch
    tokens, types, labels, attention_mask = process_batch(batch_)
    timers('batch-generator').stop()

    # Forward model: first stage consumes raw tokens, later stages consume
    # the incoming activation.
    if mpu.is_pipeline_first_stage():
        assert input_tensor is None
        output_tensor = model(tokens, attention_mask, tokentype_ids=types)
    else:
        assert input_tensor is not None
        output_tensor = model(input_tensor, attention_mask)

    if mpu.is_pipeline_last_stage():
        logits = output_tensor

        # Cross-entropy loss.
        loss_func = torch.nn.CrossEntropyLoss()
        loss = loss_func(logits.contiguous().float(), labels)

        # Reduce loss for logging.
        averaged_loss = average_losses_across_data_parallel_group([loss])
        return loss, {'lm loss': averaged_loss[0]}
    return output_tensor
def forward_step(data_iterator, model, input_tensor):
    """Forward step."""
    args = get_args()
    timers = get_timers()

    # Get the batch.
    timers('batch-generator').start()
    query_tokens, query_mask, \
        context_tokens, context_mask, context_indices = get_ict_batch(data_iterator)
    timers('batch-generator').stop()

    # Query and Context Types
    query_types = torch.cuda.LongTensor(*query_tokens.shape).fill_(0)
    context_types = torch.cuda.LongTensor(*context_tokens.shape).fill_(0)

    # Forward model.
    query_logits, context_logits = model(query_tokens, query_mask,
                                         query_types, context_tokens,
                                         context_mask, context_types)

    micro_batch_size = query_logits.shape[0]
    # recall we assert that tensor_model_parallel_size == 1
    assert mpu.get_tensor_model_parallel_world_size() == 1, \
        "Model parallel size > 1 not supported for ICT"
    global_batch_size = dist.get_world_size() * micro_batch_size

    all_query_logits = AllgatherFromDataParallelRegion.apply(query_logits)
    all_context_logits = AllgatherFromDataParallelRegion.apply(context_logits)

    # scores are inner products between query and context embeddings
    retrieval_scores = torch.matmul(all_query_logits,
                                    torch.transpose(all_context_logits, 0, 1))
    # scaling the retriever scores
    if args.retriever_score_scaling:
        retrieval_scores = retrieval_scores / math.sqrt(args.hidden_size)

    softmax_scores = F.log_softmax(retrieval_scores, dim=1)
    _, ranked_indices = torch.topk(softmax_scores,
                                   k=softmax_scores.shape[1],
                                   sorted=True)

    def topk_accuracy(k):
        # Fraction of queries whose own context appears in the top-k.
        hits = sum(int(i in ranked_indices[i, :k])
                   for i in range(global_batch_size))
        return torch.cuda.FloatTensor([hits / global_batch_size])

    topk_accs = [topk_accuracy(int(k))
                 for k in args.retriever_report_topk_accuracies]

    labels = torch.arange(global_batch_size).long().cuda()
    loss = F.nll_loss(softmax_scores, labels, reduction='mean')
    reduced_losses = average_losses_across_data_parallel_group(
        [loss, *topk_accs])

    # Scale the retrieval loss
    loss = loss * mpu.get_data_parallel_world_size()

    # create stats_dict with retrieval loss and all specified
    # top-k accuracies
    topk_acc_dict = {'top{}_acc'.format(k): v * 100 for k, v in
                     zip(args.retriever_report_topk_accuracies,
                         reduced_losses[1:])}
    stats_dict = dict(loss=reduced_losses[0], **topk_acc_dict)
    return loss, stats_dict
def loss_func(loss_mask, output_tensor):
    """Average per-token LM losses over the non-masked positions."""
    per_token_losses = output_tensor.float()
    flat_mask = loss_mask.view(-1).float()
    loss = torch.sum(per_token_losses.view(-1) * flat_mask) / flat_mask.sum()

    # Reduce loss for logging.
    averaged = average_losses_across_data_parallel_group([loss])
    return loss, {'lm loss': averaged[0]}
def loss_func(loss_mask, output_tensor):
    """Masked mean of the per-token LM loss."""
    token_losses = output_tensor.float()
    masked_sum = torch.sum(token_losses.view(-1) * loss_mask.reshape(-1))
    lm_loss = masked_sum / loss_mask.sum()

    averaged = average_losses_across_data_parallel_group([lm_loss])
    return lm_loss, {'lm loss': averaged[0]}
def cross_entropy_loss_func(labels, output_tensor):
    """Cross-entropy loss over the model logits."""
    # Cross-entropy loss.
    loss = F.cross_entropy(output_tensor.contiguous().float(), labels)

    # Reduce loss for logging.
    averaged = average_losses_across_data_parallel_group([loss])
    return loss, {'lm loss': averaged[0]}
def forward_step(data_iterator, model, input_tensor):
    """Forward step."""
    # NOTE(review): `args` is unused below — presumably kept for symmetry
    # with the other forward_step variants; confirm before removing.
    args = get_args()
    timers = get_timers()

    # Get the batch.
    timers('batch-generator').start()
    tokens, loss_mask, lm_labels, padding_mask, attention_mask, position_ids \
        = get_batch(data_iterator)
    timers('batch-generator').stop()

    # Combine the padding-derived mask with the supplied attention mask
    # by element-wise addition.
    extended_attention_mask = bert_extended_attention_mask(
        padding_mask) + attention_mask

    # Forward pass through the model. The argument set depends on the
    # pipeline stage: the first stage consumes raw tokens, later stages
    # consume the activation handed over in `input_tensor`; only the last
    # stage is given `lm_labels` so it can produce per-token losses.
    if mpu.is_pipeline_first_stage():
        assert input_tensor is None
        if mpu.is_pipeline_last_stage():
            output_tensor = model(tokens, extended_attention_mask,
                                  tokentype_ids=None,
                                  lm_labels=lm_labels,
                                  position_ids=position_ids)
        else:
            output_tensor = model(tokens, extended_attention_mask,
                                  tokentype_ids=None)
    elif mpu.is_pipeline_last_stage():
        assert input_tensor is not None
        output_tensor = model(input_tensor, extended_attention_mask,
                              lm_labels=lm_labels)
    else:
        assert input_tensor is not None
        output_tensor = model(input_tensor, extended_attention_mask,
                              position_ids=position_ids)

    if mpu.is_pipeline_last_stage():
        # Output is a pair; the second element is unused here.
        lm_loss_, _ = output_tensor

        lm_loss_ = lm_loss_.float()
        loss_mask = loss_mask.float()
        # Masked mean: average per-token loss over non-masked positions.
        lm_loss = torch.sum(
            lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()

        loss = lm_loss
        # Average across the data-parallel group for logging.
        averaged_losses = average_losses_across_data_parallel_group([
            lm_loss,
        ])
        return loss, {'lm loss': averaged_losses[0]}
    # Intermediate stages forward their activation to the next stage.
    return output_tensor
def loss_func(labels, output_tensor):
    """Cross-entropy loss and top-1 accuracy for classification."""
    logits = output_tensor.contiguous().float()
    loss = F.cross_entropy(logits, labels)

    predictions = torch.argmax(logits, -1)
    accuracy = (predictions == labels).float().mean()

    averaged = average_losses_across_data_parallel_group([loss, accuracy])
    return loss, {"loss": averaged[0], "accuracy": averaged[1]}
def forward_step(data_iterator, model, input_tensor):
    """Forward step."""
    args = get_args()
    timers = get_timers()

    # Get the batch.
    timers('batch-generator').start()
    tokens, types, sentence_order, loss_mask, lm_labels, padding_mask \
        = get_batch(data_iterator)
    timers('batch-generator').stop()

    on_first_stage = mpu.is_pipeline_first_stage()
    on_last_stage = mpu.is_pipeline_last_stage()

    # Forward pass through the model; only the last stage gets labels.
    if on_first_stage:
        assert input_tensor is None
        if on_last_stage:
            output_tensor = model(tokens, padding_mask,
                                  tokentype_ids=types, lm_labels=lm_labels)
        else:
            output_tensor = model(tokens, padding_mask, tokentype_ids=types)
    else:
        assert input_tensor is not None
        if on_last_stage:
            output_tensor = model(input_tensor, padding_mask,
                                  lm_labels=lm_labels)
        else:
            output_tensor = model(input_tensor, padding_mask)

    if not on_last_stage:
        # Intermediate stages hand the activation to the next stage.
        return output_tensor

    lm_losses, sop_logits = output_tensor

    # Sentence-order-prediction loss (binary; -1 entries are ignored).
    sop_loss = F.cross_entropy(sop_logits.view(-1, 2).float(),
                               sentence_order.view(-1),
                               ignore_index=-1).float()

    # Masked LM loss: mean over non-masked token positions.
    lm_losses = lm_losses.float()
    mask = loss_mask.float()
    lm_loss = torch.sum(lm_losses.view(-1) * mask.reshape(-1)) / mask.sum()

    loss = lm_loss + sop_loss
    averaged = average_losses_across_data_parallel_group(
        [lm_loss, sop_loss])
    return loss, {
        'lm loss': averaged[0],
        'sop loss': averaged[1]
    }
def forward_step(data_iterator, model, input_tensor):
    """Forward step."""
    args = get_args()
    timers = get_timers()

    # Get the batch.
    timers('batch-generator').start()
    tokens, labels, loss_mask, attention_mask, position_ids = get_batch(
        data_iterator)
    timers('batch-generator').stop()

    on_first_stage = mpu.is_pipeline_first_stage()
    on_last_stage = mpu.is_pipeline_last_stage()

    # Forward pass through the model; only the last stage gets labels.
    if on_first_stage:
        assert input_tensor is None
        if on_last_stage:
            output_tensor = model(tokens, position_ids, attention_mask,
                                  labels=labels)
        else:
            output_tensor = model(tokens, position_ids, attention_mask)
    else:
        assert input_tensor is not None
        if on_last_stage:
            output_tensor = model(input_tensor, attention_mask,
                                  labels=labels)
        else:
            output_tensor = model(input_tensor, attention_mask)

    if not on_last_stage:
        # Intermediate stages hand the activation to the next stage.
        return output_tensor

    # Masked mean of the per-token losses.
    token_losses = output_tensor.float()
    flat_mask = loss_mask.view(-1).float()
    loss = torch.sum(token_losses.view(-1) * flat_mask) / flat_mask.sum()

    # Reduce loss for logging.
    averaged = average_losses_across_data_parallel_group([loss])
    return loss, {'lm loss': averaged[0]}
def forward_step(data_iterator, model, input_tensor):
    """Forward step."""
    args = get_args()
    timers = get_timers()

    # Get the batch.
    timers('batch-generator').start()
    query_tokens, query_pad_mask, \
        block_tokens, block_pad_mask, block_indices = get_ict_batch(data_iterator)
    timers('batch-generator').stop()

    # Forward model.
    query_logits, block_logits = model(query_tokens, query_pad_mask,
                                       block_tokens, block_pad_mask)

    micro_batch_size = query_logits.shape[0]
    global_batch_size = dist.get_world_size() * micro_batch_size

    # recall we assert that tensor_model_parallel_size == 1
    all_query_logits = AllgatherFromDataParallelRegion.apply(query_logits)
    all_block_logits = AllgatherFromDataParallelRegion.apply(block_logits)

    # scores are inner products between query and block embeddings
    retrieval_scores = all_query_logits.float().matmul(
        torch.transpose(all_block_logits, 0, 1).float())
    probs = F.softmax(retrieval_scores, dim=1)
    _, ranked = torch.topk(probs, k=probs.shape[1], sorted=True)

    def topk_accuracy(k):
        # Fraction of queries whose own block appears in the top-k.
        hits = sum(int(i in ranked[i, :k]) for i in range(global_batch_size))
        return torch.cuda.FloatTensor([hits / global_batch_size])

    topk_accs = [topk_accuracy(int(k)) for k in args.report_topk_accuracies]

    labels = torch.arange(global_batch_size).long().cuda()
    retrieval_loss = torch.nn.CrossEntropyLoss()(retrieval_scores,
                                                 labels).float()

    averaged_losses = average_losses_across_data_parallel_group(
        [retrieval_loss, *topk_accs])

    # create stats_dict with retrieval loss and all specified top-k accuracies
    topk_acc_dict = {'top{}_acc'.format(k): v for k, v in
                     zip(args.report_topk_accuracies, averaged_losses[1:])}
    stats_dict = dict(retrieval_loss=averaged_losses[0], **topk_acc_dict)
    return retrieval_loss, stats_dict
def forward_step(data_iterator, model, input_tensor):
    """Forward step for image classification.

    Runs the model on a batch of images and returns the cross-entropy
    loss plus data-parallel-averaged loss/accuracy stats. No pipeline
    parallelism here, hence the `input_tensor is None` assertion.
    """
    timers = get_timers()
    assert input_tensor is None

    # Get the batch.
    timers("batch-generator").start()
    (
        images,
        labels,
    ) = get_batch(data_iterator)
    timers("batch-generator").stop()

    # Forward model. (Removed a stray leftover `lm_labels` token that had
    # been fused into this comment — it served no purpose here.)
    logits = model(images).contiguous().float()
    loss = F.cross_entropy(logits, labels)

    # Top-1 accuracy.
    outputs = torch.argmax(logits, -1)
    correct = (outputs == labels).float()
    accuracy = torch.mean(correct)

    # Reduce for logging.
    averaged_loss = average_losses_across_data_parallel_group([loss, accuracy])
    return loss, {"loss": averaged_loss[0], "accuracy": averaged_loss[1]}
def _cross_entropy_forward_step(batch, model, input_tensor):
    """Simple forward step with cross-entropy loss.

    `batch` may be an iterator (normal case) or an already-materialized
    batch. Returns (loss, stats) with the loss averaged across the
    data-parallel group for logging.
    """
    timers = get_timers()
    assert input_tensor is None

    # Get the batch.
    timers("batch generator").start()
    try:
        batch_ = next(batch)
    except Exception:
        # Fix: was `except BaseException`, which also swallowed
        # KeyboardInterrupt/SystemExit. `batch` may already be a batch
        # rather than an iterator; fall back to using it directly.
        batch_ = batch
    images, labels = process_batch(batch_)
    timers("batch generator").stop()

    # Forward model.
    logits = model(images).contiguous().float()

    # Cross-entropy loss.
    loss = F.cross_entropy(logits, labels)

    # Reduce loss for logging.
    average_loss = average_losses_across_data_parallel_group([loss])
    return loss, {"lm loss": average_loss[0]}
def cross_entropy_loss_func(query_tokens, context_tokens, output_tensor):
    """Retrieval cross-entropy loss over all-gathered query/context logits.

    Returns (loss, stats) where stats carries the data-parallel-averaged
    loss and the count of correct top-1 predictions.
    """
    args = get_args()

    local_batch_size = query_tokens.shape[0]
    group, rank, world_size = get_group_world_size_rank()
    # recall we assert that model_parallel_size == 1
    global_batch_size = world_size * local_batch_size

    query_logits, context_logits = output_tensor

    if world_size > 1:
        # All-gather the context logits across the data-parallel group:
        # a detached copy is gathered, then this rank's slot is swapped
        # back to the local (grad-tracking) tensor so gradients flow.
        input_ = torch.empty_like(context_logits).copy_(\
            context_logits).detach_()
        tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
        tensor_list[rank].copy_(input_)
        torch.distributed.all_gather(tensor_list, input_, group=group)

        # Check if all-gather happens in order
        assert tensor_list[rank].sum().item() == \
            context_logits.sum().item()

        # Preserves the gradient
        tensor_list[rank] = context_logits
        all_context_logits = torch.cat(tensor_list, dim=0).contiguous()

        # Query tensors: same gather-with-gradient trick as above.
        input_ = torch.empty_like(query_logits).copy_(\
            query_logits).detach_()
        tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
        tensor_list[rank].copy_(input_)
        torch.distributed.all_gather(tensor_list, input_, group=group)

        # Check if all-gather happens in order
        assert tensor_list[rank].sum().item() == query_logits.sum().item()

        # Preserves the gradient
        tensor_list[rank] = query_logits
        all_query_logits = torch.cat(tensor_list, dim=0).contiguous()
    else:
        all_query_logits = query_logits
        all_context_logits = context_logits

    # Scores are inner products between query and context embeddings.
    retrieval_scores = torch.matmul(
        all_query_logits, torch.transpose(all_context_logits, 0, 1))
    # Scaling the retrieval scores
    if args.retriever_score_scaling:
        retrieval_scores = retrieval_scores / math.sqrt(args.hidden_size)

    if args.train_with_neg:
        # if the world size is 3, local batch size is 4, and
        # local context size is 8, what we want is
        # labels = [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19]
        labels = []
        local_context_size = context_tokens.shape[0]
        for i in range(world_size):
            j = i * local_context_size
            labels.extend(list(range(j, j + local_batch_size)))
        labels = torch.LongTensor(labels).cuda()
        assert len(labels) == global_batch_size
    else:
        labels = torch.arange(global_batch_size).long().cuda()

    # Cross-entropy loss.
    softmax_scores = F.log_softmax(retrieval_scores, dim=1)
    loss = F.nll_loss(softmax_scores, labels, reduction='mean')

    # Top-1 prediction count for the stats dict.
    max_score, max_idxs = torch.max(softmax_scores, 1)
    correct_predictions_count = (max_idxs == labels).sum().float()

    # Reduce loss for logging.
    reduced_loss = average_losses_across_data_parallel_group([loss, \
        correct_predictions_count])

    # Loss scaling for correct losses in Supervised Retrieval
    loss = loss * mpu.get_data_parallel_world_size()

    return loss, {
        'lm loss': reduced_loss[0],
        'correct_prediction_count': reduced_loss[1]
    }
def retrieval_loss(model, dataloader):
    """Evaluate retrieval quality of `model` over `dataloader`.

    Returns (stats_dict, total): summed rank and top-k hit counts
    (scaled by 100 per batch) plus the number of examples processed.
    Caller is expected to normalize by `total`.
    """
    args = get_args()
    total = 0
    topk_stats_dict = {'top{}_acc'.format(k): 0 for k in \
        args.retriever_report_topk_accuracies}
    stats_dict = dict(rank=0, **topk_stats_dict)

    assert len(model) == 1
    unwrapped_model = model[0]
    unwrapped_model.eval()

    with torch.no_grad():
        # For all the batches in the dataset.
        for batch in dataloader:
            # Run the model forward. Positive and hard-negative contexts
            # are concatenated so scores are computed against both.
            query_tokens, query_mask, query_types, _, \
                context_tokens, context_mask, context_types, _, \
                neg_context_tokens, neg_context_mask, neg_context_types, \
                reference = process_batch(batch)

            query_logits, context_logits = unwrapped_model(
                query_tokens, query_mask, query_types,
                torch.cat([context_tokens, neg_context_tokens]),
                torch.cat([context_mask, neg_context_mask]),
                torch.cat([context_types, neg_context_types]))

            # Scores are inner products between query and context logits.
            retrieval_scores = torch.matmul(query_logits,
                torch.transpose(context_logits, 0, 1))

            if args.retriever_score_scaling:
                retrieval_scores = retrieval_scores / \
                    math.sqrt(args.hidden_size)

            local_batch_size = query_logits.shape[0]
            # Gold context for query i is context i.
            labels = torch.arange(local_batch_size).long().cuda()

            softmax_scores = F.softmax(retrieval_scores, dim=1)
            sorted_vals, sorted_indices = torch.topk(softmax_scores,
                k=softmax_scores.shape[1], sorted=True)

            def topk_accuracy(k):
                # Count of queries whose gold context lands in the top-k.
                return torch.cuda.FloatTensor(
                    [sum([int(labels[i] in sorted_indices[i, :k]) for i in \
                        range(local_batch_size)])])

            def get_rank():
                # Sum of the (0-based) ranks of the gold contexts.
                return torch.cuda.FloatTensor(
                    [sum([torch.nonzero(labels[i] == sorted_indices[i])[0][0] \
                        for i in range(local_batch_size)])])

            topk_accs = [topk_accuracy(k) for k in \
                args.retriever_report_topk_accuracies]
            rank = get_rank()
            losses = average_losses_across_data_parallel_group([rank, \
                *topk_accs])

            # create stats_dict with retrieval loss and all specified
            # top-k accuracies
            topk_acc_dict = {'top{}_acc'.format(k): v * 100 for k, v in \
                zip(args.retriever_report_topk_accuracies, losses[1:])}
            temp_stats_dict = dict(rank=losses[0], **topk_acc_dict)

            # Accumulate per-batch stats into the running totals.
            for k in stats_dict.keys():
                stats_dict[k] += temp_stats_dict[k]
            total += local_batch_size

    unwrapped_model.train()

    return stats_dict, total