import torch
import torch.nn.functional as F

from megatron import get_args
from megatron import get_tokenizer
from megatron import mpu

# `forward_step`, `top_k_logits`, and `switch` are helpers defined elsewhere
# in this module. The two sanity-check helpers below are included so the
# snippet is self-contained; their bodies follow directly from how they are
# used.


def _is_cuda(tensor):
    """Check that a tensor is not None and lives on a CUDA device."""
    assert tensor is not None
    assert tensor.is_cuda


def _is_cuda_contiguous(tensor):
    """Check that a tensor is not None, on CUDA, and contiguous."""
    _is_cuda(tensor)
    assert tensor.is_contiguous()


def broadcast_from_last_to_first_pipeline_stage(size, dtype, tensor=None):
    """Broadcast tensor values from the last stage into the first stage."""

    is_last_stage = mpu.is_pipeline_last_stage()
    is_first_stage = mpu.is_pipeline_first_stage()
    # If the first stage and the last stage are the same, then there is no
    # pipeline parallelism and no need to communicate.
    if is_first_stage and is_last_stage:
        return tensor
    # Only the first and last pipeline stages need to be involved.
    if is_last_stage or is_first_stage:
        if is_last_stage:
            _is_cuda_contiguous(tensor)
        else:
            tensor = torch.empty(size,
                                 dtype=dtype,
                                 device=torch.cuda.current_device())
        src = mpu.get_pipeline_model_parallel_last_rank()
        group = mpu.get_embedding_group()
        # Broadcast from the last stage into the first stage.
        torch.distributed.broadcast(tensor, src, group)
    else:
        tensor = None

    return tensor
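# --- Illustrative usage (a sketch, not part of the original module) ---
# How the broadcast helper above is typically driven: the last stage passes
# the tensor it computed, every other rank passes None, and the first stage
# receives a freshly allocated copy. The `batch_size`/`seq_length` names and
# the float32 log-prob tensor are assumptions for the example only.
def _example_broadcast_log_probs(batch_size, seq_length, output_log_probs=None):
    # `output_log_probs` is non-None only on the last pipeline stage.
    return broadcast_from_last_to_first_pipeline_stage(
        [batch_size, seq_length], torch.float32, tensor=output_log_probs)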
def copy_from_last_to_first_pipeline_stage(size, dtype, tensor=None):
    """Copy tensor values from the last stage into the first stage.
    Note that the input tensor is updated in place."""

    is_last_stage = mpu.is_pipeline_last_stage()
    is_first_stage = mpu.is_pipeline_first_stage()
    # If the first stage and the last stage are the same, then there is no
    # pipeline parallelism and no need to communicate.
    if is_first_stage and is_last_stage:
        return
    # Only the first and last pipeline stages need to be involved.
    if is_last_stage or is_first_stage:
        _is_cuda(tensor)
        is_contiguous = tensor.is_contiguous()
        src = mpu.get_pipeline_model_parallel_last_rank()
        group = mpu.get_embedding_group()
        if is_contiguous:
            tensor_ = tensor
        else:
            if is_last_stage:
                tensor_ = tensor.contiguous()
            else:
                tensor_ = torch.empty(size,
                                      dtype=dtype,
                                      device=torch.cuda.current_device())
        # Broadcast from the last stage into the first stage.
        torch.distributed.broadcast(tensor_, src, group)
        # Update the first stage tensor in place if the broadcast had to be
        # staged through a contiguous temporary.
        if is_first_stage and not is_contiguous:
            tensor[...] = tensor_
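# --- Illustrative usage (assumed, not from the original module) ---
# The copy variant is useful when the first stage already owns a buffer
# (possibly a non-contiguous view) that must keep its identity across calls;
# the helper writes the received values back into it in place.
def _example_copy_generated_tokens(tokens, batch_size, seq_length):
    copy_from_last_to_first_pipeline_stage(
        [batch_size, seq_length], tokens.dtype, tensor=tokens)
    # On the first stage, `tokens` now mirrors the last stage's values.
    return tokens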
def broadcast_from_last_pipeline_stage(size, dtype, tensor=None):
    """Broadcast a tensor from the last pipeline stage to all ranks."""

    is_last_stage = mpu.is_pipeline_last_stage()
    # If the first stage and the last stage are the same, then there is no
    # pipeline parallelism and no need to communicate.
    if mpu.is_pipeline_first_stage() and is_last_stage:
        return tensor

    if is_last_stage:
        _is_cuda_contiguous(tensor)
    else:
        tensor = torch.empty(size,
                             dtype=dtype,
                             device=torch.cuda.current_device())
    # Get the group and corresponding source rank.
    src = mpu.get_pipeline_model_parallel_last_rank()
    group = mpu.get_pipeline_model_parallel_group()
    torch.distributed.broadcast(tensor, src, group)

    return tensor
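# --- Illustrative usage (an assumption for this example) ---
# Broadcasting to *all* ranks is what intermediate stages need when every
# stage must see the same control data, e.g. the final generated lengths.
def _example_broadcast_lengths(batch_size, lengths=None):
    # Only the last stage supplies `lengths`; all other ranks receive it.
    return broadcast_from_last_pipeline_stage(
        [batch_size], torch.int64, tensor=lengths)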
def sample_sequence_batch(model, context_tokens, context_lengths,
                          attention_mask, position_ids,
                          maxlen=None, type_ids=None):

    args = get_args()
    tokenizer = get_tokenizer()

    model.eval()
    with torch.no_grad():
        context_length = context_lengths.min().item()
        eos_id = tokenizer.eod

        counter = 0
        org_context_length = context_length

        layer_past = None
        batch_size = context_tokens.size(0)
        is_done = torch.zeros([batch_size]).byte().cuda()
        tokens = context_tokens
        if maxlen is None:
            maxlen = args.seq_length - 1
            if maxlen > (org_context_length + args.out_seq_length):
                maxlen = org_context_length + args.out_seq_length

        lengths = torch.ones([batch_size]).long().cuda() * maxlen

        while context_length <= maxlen:
            if args.recompute:
                # Recompute mode: re-run the full forward pass every step
                # instead of caching key/value pairs.
                output = forward_step(model, tokens,
                                      position_ids,
                                      attention_mask,
                                      tokentype_ids=type_ids,
                                      forward_method_parallel_output=False)
                if mpu.is_pipeline_last_stage():
                    assert output is not None
                    logits = output[:, context_length - 1, :]
            else:
                # Incremental mode: run the full context once, then feed a
                # single token per step using the cached `layer_past`.
                types2use = None
                if counter == 0:
                    tokens2use = tokens[:, :context_length]
                    positions2use = position_ids[:, :context_length]
                    if type_ids is not None:
                        types2use = type_ids[:, :context_length]
                else:
                    tokens2use = tokens[:, context_length - 1].view(
                        batch_size, -1)
                    positions2use = position_ids[:, context_length - 1].view(
                        batch_size, -1)
                    if type_ids is not None:
                        types2use = type_ids[:, context_length - 1].view(
                            batch_size, -1)
                output, layer_past = forward_step(
                    model, tokens2use, positions2use, attention_mask,
                    layer_past=layer_past, get_key_value=True,
                    tokentype_ids=types2use,
                    forward_method_parallel_output=False)
                if mpu.is_pipeline_last_stage():
                    assert output is not None
                    logits = output[:, -1].view(batch_size, -1).contiguous()

            if mpu.is_pipeline_last_stage():
                if args.greedy:
                    prev = torch.argmax(logits, dim=-1).view(-1)
                else:
                    logits = logits.float()
                    logits /= args.temperature
                    logits = top_k_logits(logits, top_k=args.top_k,
                                          top_p=args.top_p)
                    log_probs = F.softmax(logits, dim=-1)
                    prev = torch.multinomial(log_probs,
                                             num_samples=1).view(-1)

                # Only overwrite positions whose prompt has already ended.
                started = context_lengths <= context_length
                new_tokens = switch(
                    tokens[:, context_length].view(-1), prev, started)
                tokens[:, context_length] = new_tokens
                # Send the sampled tokens to the first stage.
                src = mpu.get_pipeline_model_parallel_last_rank()
                group = mpu.get_embedding_group()
                torch.distributed.broadcast(new_tokens, src, group)

                # Track which sequences just emitted an end-of-document token.
                done_token = (prev == eos_id).byte() & started.byte()
                just_finished = (done_token & ~is_done).bool()
                lengths[just_finished.view(-1)] = context_length
                is_done = is_done | done_token

                # Tell the whole pipeline whether generation has finished.
                done = torch.all(is_done)
                src = mpu.get_pipeline_model_parallel_last_rank()
                group = mpu.get_pipeline_model_parallel_group()
                torch.distributed.broadcast(done, src, group)

                yield tokens, lengths

            else:
                if mpu.is_pipeline_first_stage():
                    # Receive the sampled tokens from the last stage.
                    src = mpu.get_pipeline_model_parallel_last_rank()
                    group = mpu.get_embedding_group()
                    new_tokens = torch.empty_like(tokens[:, context_length])
                    torch.distributed.broadcast(new_tokens, src, group)
                    tokens[:, context_length] = new_tokens
                    yield tokens, None
                else:
                    yield None, None

                # Receive the done flag from the last stage.
                done = torch.cuda.ByteTensor([0])
                src = mpu.get_pipeline_model_parallel_last_rank()
                group = mpu.get_pipeline_model_parallel_group()
                torch.distributed.broadcast(done, src, group)

            context_length += 1
            counter += 1
            if done:
                break
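# --- Illustrative driver (assumed, not from the original module) ---
# `sample_sequence_batch` is a generator that yields once per generated
# position; a caller typically drains it and keeps the final yield. Note that
# intermediate stages yield (None, None) and the first stage yields
# (tokens, None). `build_masks_and_position_ids` is a hypothetical helper
# standing in for whatever builds the attention mask and position ids.
def _example_generate(model, context_tokens, context_lengths):
    # Hypothetical helper; not defined in this module.
    attention_mask, position_ids = build_masks_and_position_ids(context_tokens)
    tokens, lengths = None, None
    for tokens, lengths in sample_sequence_batch(
            model, context_tokens, context_lengths,
            attention_mask, position_ids):
        pass  # each step could be streamed to the caller here
    return tokens, lengths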