def seq2seq_forward_step(data, model, args, timers, mems):
    """Forward step."""
    # Get the batch.
    if timers is not None:
        timers('batch generator').start()
    tokens, labels, loss_mask, attention_mask, position_ids = get_batch(data, args)
    if timers is not None:
        timers('batch generator').stop()
    # Forward model.
    logits, *mems = model(tokens, position_ids, attention_mask, *mems)
    # Only the target (decoder) part of the sequence contributes to the loss.
    logits = logits[:, args.src_seq_length:]
    loss_mask = loss_mask[:, args.src_seq_length:]
    labels = labels[:, args.src_seq_length:]
    losses = mpu.vocab_parallel_cross_entropy(logits.contiguous().float(), labels)
    if args.label_smoothing > 0.0:
        # Mix the per-token cross entropy with the uniform (smoothed) loss.
        epsilon = args.label_smoothing
        smooth_loss = -torch.nn.functional.log_softmax(logits, dim=-1).mean(dim=-1)
        losses = (1 - epsilon) * losses + epsilon * smooth_loss
    loss_mask = loss_mask.reshape(-1)
    # Average the loss over the non-masked target tokens.
    loss = torch.sum(losses.reshape(-1) * loss_mask) / loss_mask.sum()
    return loss, mems, 'bert'
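
# Illustrative sketch (not part of the original pipeline): the label-smoothing
# branch in seq2seq_forward_step mixes the per-token cross entropy with the mean
# negative log-probability over the vocabulary. The hypothetical helper below
# reproduces that combination with plain torch.nn.functional instead of the
# model-parallel mpu.vocab_parallel_cross_entropy, assuming `logits` of shape
# [batch, seq, vocab] and integer `labels` of shape [batch, seq].
def _label_smoothed_cross_entropy_sketch(logits, labels, epsilon):
    log_probs = torch.nn.functional.log_softmax(logits.float(), dim=-1)
    nll = -log_probs.gather(dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)
    smooth = -log_probs.mean(dim=-1)
    return (1 - epsilon) * nll + epsilon * smooth
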
def lm_forward_step(data, model, args, timers, mems, eval_metric=None):
    """Forward step."""
    # Get the batch.
    tokens, labels, loss_mask, attention_mask, position_ids = get_batch(data, args)

    def print_masked_text(batch_id):
        # Debugging helper: decode and print one example, showing the context
        # (up to the separator) and each generated piece with its position ids.
        block_position_ids = position_ids[:, 1]
        position_ids_ = position_ids[:, 0]
        output_tokens = []
        sep = attention_mask[batch_id].item()
        for i, token in enumerate(tokens[batch_id, :sep].tolist()):
            if global_tokenizer is not None:
                token = global_tokenizer.IdToToken(token)
                if token.startswith('[MASK'):
                    token = f"[{position_ids_[batch_id, i].item()}, {token}]"
                if token.startswith('##') and len(output_tokens) > 0 and not output_tokens[-1].endswith(']'):
                    output_tokens[-1] += token[2:]
                else:
                    output_tokens.append(token)
            else:
                output_tokens.append(str(token))
        print(" ".join(output_tokens))
        last_index = None
        for i in range(sep, tokens.size(1)):
            if global_tokenizer.IdToToken(tokens[batch_id, i].item()).startswith("<|startofpiece"):
                if last_index is not None:
                    print(global_tokenizer.DecodeIds(tokens[batch_id, last_index:i].tolist()), "|",
                          global_tokenizer.DecodeIds(labels[batch_id, last_index:i].tolist()))
                    print(position_ids_[batch_id, last_index:i].tolist(),
                          block_position_ids[batch_id, last_index:i].tolist())
                last_index = i
        if last_index is not None:
            print(global_tokenizer.DecodeIds(tokens[batch_id, last_index:].tolist()), "|",
                  global_tokenizer.DecodeIds(labels[batch_id, last_index:].tolist()))
            print(position_ids_[batch_id, last_index:].tolist(),
                  block_position_ids[batch_id, last_index:].tolist())

    # Forward model.
    if args.continuous_prompt:
        prompt_pos = data["prompt_pos"].long().cuda()
        logits, *mems = model(tokens, position_ids, attention_mask, *mems, prompt_pos=prompt_pos)
    else:
        logits, *mems = model(tokens, position_ids, attention_mask, *mems)
    if eval_metric is None or eval_metric == 'loss':
        losses = mpu.vocab_parallel_cross_entropy(logits.contiguous().float(), labels)
        loss_mask = loss_mask.view(-1)
        # For evaluation, return the unnormalized sum (for fair comparison);
        # during training, average over the non-masked tokens.
        loss = torch.sum(losses.view(-1) * loss_mask)
        if eval_metric is None:
            loss = loss / loss_mask.sum()
        return loss, mems, 'bert'
    elif eval_metric == 'accuracy' or eval_metric == 'classify':
        # An example counts as correct only if every non-masked token is predicted correctly.
        outputs = torch.argmax(logits, -1)
        correct = (outputs == labels).float()
        correct[(1 - loss_mask).bool()] = 1
        correct = correct.prod(-1)
        if eval_metric == 'accuracy':
            correct = correct.sum()
        return correct, mems, 'bert'
    else:
        raise NotImplementedError("Metric {} not implemented".format(eval_metric))
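
# Illustrative sketch (not from the original code): the 'accuracy' branch of
# lm_forward_step scores an example as correct only when every non-masked
# position is predicted correctly, because masked positions are forced to 1
# before taking the product over the sequence dimension. A hypothetical
# standalone version with plain torch, under the same shape assumptions
# ([batch, seq, vocab] logits, [batch, seq] labels and loss_mask):
def _sequence_accuracy_sketch(logits, labels, loss_mask):
    predictions = torch.argmax(logits, dim=-1)
    correct = (predictions == labels).float()
    correct[(1 - loss_mask).bool()] = 1.0  # ignore masked / padded positions
    per_example = correct.prod(dim=-1)     # 1.0 only if all target tokens match
    return per_example.sum()               # number of fully-correct examples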