Example #1
 def initialize_word_embeddings(self, init_method_normal):
     args = get_args()
     if not self.share_word_embeddings:
         raise Exception('initialize_word_embeddings() was called but '
                         'share_word_embeddings is false')
     # Parameters are shared between the word embeddings layer, and the
     # heads at the end of the model. In a pipelined setup with more than
     # one stage, the initial embedding layer and the head are on different
     # workers, so we do the following:
     # 1. Create a second copy of word_embeddings on the last stage, with
     #    initial parameters of 0.0.
     # 2. Do an all-reduce between the first and last stage to ensure that
     #    the two copies of word_embeddings start off with the same
     #    parameter values.
     # 3. In the training loop, do an all-reduce between the grads of the
     #    two word_embeddings layers to ensure that every applied weight
     #    update is the same on both stages.
     if mpu.is_pipeline_last_stage():
         if not mpu.is_pipeline_first_stage():
             self._word_embeddings_for_head_key = 'word_embeddings_for_head'
             # If first and last stages are different, set word_embeddings
             # weights to 0 here, then copy first stage's weights using
             # all_reduce below.
             self.word_embeddings = mpu.VocabParallelEmbedding(
                 args.padded_vocab_size,
                 args.hidden_size,
                 init_method=init_method_normal(args.init_method_std))
             self.word_embeddings.weight.data.fill_(0)
             self.word_embeddings.weight.shared = True
     # Ensure that first and last stages have the same initial parameter
     # values.
     if mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage():
         torch.distributed.all_reduce(self.word_embeddings_weight().data,
                                      group=mpu.get_embedding_group())
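
Step 3 of the scheme described in the comments above happens in the training loop. A minimal, hypothetical sketch of that gradient all-reduce, assuming only the mpu helpers and the word_embeddings_weight() accessor shown in these examples (the helper name allreduce_embedding_grads is illustrative, not taken from the source):

import torch

def allreduce_embedding_grads(model):
    # Only the first and last pipeline stages hold a copy of the word
    # embeddings, so only they take part in the all-reduce.
    if mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage():
        weight = model.word_embeddings_weight()
        if weight.grad is not None:
            # Sum the two copies' gradients so both stages apply the
            # same weight update (step 3 of the comment above).
            torch.distributed.all_reduce(weight.grad.data,
                                         group=mpu.get_embedding_group())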
Example #2
    def forward(self, bert_model_input, attention_mask,
                tokentype_ids=None, lm_labels=None, position_ids=None):

        if attention_mask.dim() == 2:
            extended_attention_mask = bert_extended_attention_mask(attention_mask)
        else:
            extended_attention_mask = attention_mask

        kwargs = {}
        if mpu.is_pipeline_first_stage():
            input_ids = bert_model_input
            if position_ids is None:
                position_ids = bert_position_ids(input_ids)
            args = [input_ids, position_ids, extended_attention_mask]
            kwargs['tokentype_ids'] = tokentype_ids
        else:
            args = [bert_model_input, extended_attention_mask]
        lm_output = self.language_model(*args, **kwargs)
        if mpu.is_pipeline_last_stage() and self.add_binary_head:
            lm_output, pooled_output = lm_output
        else:
            pooled_output = None

        if mpu.is_pipeline_last_stage():
            return post_language_model_processing(lm_output, pooled_output,
                                                  self.lm_head, self.binary_head,
                                                  lm_labels,
                                                  self.word_embeddings_weight(),
                                                  self.fp16_lm_cross_entropy)
        else:
            return lm_output
Example #3
def forward_step(batch, model, eval_metric):
    """Forward step."""

    # Get the batch.
    tokens, labels, attention_mask, position_ids, loss_mask = process_batch(
        batch)

    # Tell the model what our actual batch size will be
    args = get_args()
    args.micro_batch_size = len(labels)

    # Forward model.
    if not mpu.is_pipeline_first_stage():
        input_tensor, _ = communicate(tensor_send_next=None,
                                      tensor_send_prev=None,
                                      recv_forward=True,
                                      recv_backward=False)
    else:
        input_tensor = None

    # Forward pass through the model.
    if mpu.is_pipeline_first_stage():
        assert input_tensor is None
        output = model(tokens, position_ids, attention_mask)
    else:
        assert input_tensor is not None
        output = model(input_tensor, attention_mask)

    if not mpu.is_pipeline_last_stage():
        communicate(tensor_send_next=output,
                    tensor_send_prev=None,
                    recv_forward=False,
                    recv_backward=False)
        return None

    if mpu.is_pipeline_last_stage():
        # For loss, return the unreduced loss.
        if eval_metric == 'loss':
            losses = mpu.vocab_parallel_cross_entropy(
                output.contiguous().float(), labels.contiguous())
            loss = torch.sum(
                losses.view(-1) * loss_mask.contiguous().view(-1).float())
            return loss

        # For accuracy, return the number of correctly predicted samples.
        if eval_metric == 'accuracy':
            outputs = torch.argmax(output, -1)
            correct = (outputs == labels).float()
            correct[(1 - loss_mask).bool()] = 1
            correct = correct.prod(-1)
            return correct.sum()

        raise NotImplementedError('forward method for evaluation metric {} '
                                  'is not implemented.'.format(eval_metric))
    return None
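
To make the accuracy reduction above concrete, a small self-contained check with illustrative values (not from the source): positions zeroed out by loss_mask are forced to count as correct, so prod(-1) is 1 exactly when every unmasked position matches.

import torch

outputs = torch.tensor([[5, 2, 9], [5, 2, 9]])
labels = torch.tensor([[5, 2, 7], [5, 2, 9]])
loss_mask = torch.tensor([[1., 1., 0.], [1., 1., 1.]])

correct = (outputs == labels).float()
correct[(1 - loss_mask).bool()] = 1     # masked positions always count as correct
per_sample = correct.prod(-1)           # tensor([1., 1.])
print(per_sample.sum())                 # tensor(2.) -> both samples are exact matches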
Example #4
def _with_pipelining_forward_step(model, tokens, position_ids, attention_mask,
                                  inference_params, micro_batch_size):
    """No interleaving is supported."""
    sequence_length = tokens.size(1)
    batch_size = tokens.size(0)

    # Divide the batch dimension into micro batches.
    num_micro_batches, last_chunk = divmod(batch_size, micro_batch_size)
    if last_chunk > 0:
        num_micro_batches += 1

    # Preallocate memory for output logits.
    logits = None
    if mpu.is_pipeline_last_stage():
        args = get_args()
        logits = torch.empty(
            (batch_size, sequence_length, args.padded_vocab_size),
            dtype=torch.float32,
            device=torch.cuda.current_device())

    # Preallocate recv buffer.
    recv_buffer = _allocate_recv_buffer(micro_batch_size, sequence_length)

    for micro_batch_index in range(num_micro_batches):
        # Slice along the batch dimension.
        start = micro_batch_index * micro_batch_size
        end = min(start + micro_batch_size, batch_size)
        this_micro_batch_size = end - start
        tokens2use = tokens[start:end, ...]
        position_ids2use = position_ids[start:end, ...]

        # Run a simple forward pass.
        if this_micro_batch_size != micro_batch_size:
            recv_buffer = None
        output = _forward_step_helper(model,
                                      tokens2use,
                                      position_ids2use,
                                      attention_mask,
                                      inference_params,
                                      recv_buffer=recv_buffer)

        # Adjust the batch size offset to account for the micro-batch.
        inference_params.batch_size_offset += this_micro_batch_size

        # Copy logits.
        if mpu.is_pipeline_last_stage():
            logits[start:end, ...] = output

    # Once we are done with all the micro-batches, we can
    # adjust the sequence length offset.
    inference_params.sequence_len_offset += sequence_length
    # and reset the batch size offset
    inference_params.batch_size_offset = 0

    return logits
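
Example #4 calls _allocate_recv_buffer without showing it. A hedged sketch of what such a helper could look like, assuming the pipeline exchanges [sequence, batch, hidden] activations and that args exposes hidden_size and params_dtype (the latter is an assumption, not shown in these snippets):

import torch

def _allocate_recv_buffer(batch_size, sequence_length):
    # Nothing to receive on the first stage.
    if mpu.is_pipeline_first_stage():
        return None
    args = get_args()
    # Shape and dtype are assumptions: [s, b, h] activations in the
    # model's parameter dtype.
    return torch.empty((sequence_length, batch_size, args.hidden_size),
                       dtype=args.params_dtype,
                       device=torch.cuda.current_device())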
Example #5
def forward_step(data_iterator, model, input_tensor):
    """Forward step."""
    args = get_args()
    timers = get_timers()

    # Get the batch.
    timers('batch-generator').start()
    tokens, loss_mask, lm_labels, padding_mask, attention_mask, position_ids \
        = get_batch(data_iterator)
    timers('batch-generator').stop()

    extended_attention_mask = bert_extended_attention_mask(
        padding_mask) + attention_mask

    # Forward pass through the model.
    if mpu.is_pipeline_first_stage():
        assert input_tensor is None
        if mpu.is_pipeline_last_stage():
            output_tensor = model(tokens,
                                  extended_attention_mask,
                                  tokentype_ids=None,
                                  lm_labels=lm_labels,
                                  position_ids=position_ids)
        else:
            output_tensor = model(tokens,
                                  extended_attention_mask,
                                  tokentype_ids=None)
    elif mpu.is_pipeline_last_stage():
        assert input_tensor is not None
        output_tensor = model(input_tensor,
                              extended_attention_mask,
                              lm_labels=lm_labels)
    else:
        assert input_tensor is not None
        output_tensor = model(input_tensor,
                              extended_attention_mask,
                              position_ids=position_ids)

    if mpu.is_pipeline_last_stage():
        lm_loss_, _ = output_tensor

        lm_loss_ = lm_loss_.float()
        loss_mask = loss_mask.float()
        lm_loss = torch.sum(
            lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()

        loss = lm_loss

        averaged_losses = average_losses_across_data_parallel_group([
            lm_loss,
        ])

        return loss, {'lm loss': averaged_losses[0]}
    return output_tensor
Example #6
def forward_step(data_iterator, model, input_tensor):
    """Forward step."""
    args = get_args()
    timers = get_timers()

    # Get the batch.
    timers('batch-generator').start()
    tokens, types, sentence_order, loss_mask, lm_labels, padding_mask \
        = get_batch(data_iterator)
    timers('batch-generator').stop()

    # Forward pass through the model.
    if mpu.is_pipeline_first_stage():
        assert input_tensor is None
        if mpu.is_pipeline_last_stage():
            output_tensor = model(tokens,
                                  padding_mask,
                                  tokentype_ids=types,
                                  lm_labels=lm_labels)
        else:
            output_tensor = model(tokens, padding_mask, tokentype_ids=types)
    elif mpu.is_pipeline_last_stage():
        assert input_tensor is not None
        output_tensor = model(input_tensor, padding_mask, lm_labels=lm_labels)
    else:
        assert input_tensor is not None
        output_tensor = model(input_tensor, padding_mask)

    if mpu.is_pipeline_last_stage():
        lm_loss_, sop_logits = output_tensor

        sop_loss = F.cross_entropy(sop_logits.view(-1, 2).float(),
                                   sentence_order.view(-1),
                                   ignore_index=-1)
        sop_loss = sop_loss.float()

        lm_loss_ = lm_loss_.float()
        loss_mask = loss_mask.float()
        lm_loss = torch.sum(
            lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()

        loss = lm_loss + sop_loss

        averaged_losses = average_losses_across_data_parallel_group(
            [lm_loss, sop_loss])

        return loss, {
            'lm loss': averaged_losses[0],
            'sop loss': averaged_losses[1]
        }
    return output_tensor
Example #7
    def initialize_word_embeddings(self, init_method_normal):
        args = get_args()
        if not self.share_word_embeddings:
            raise Exception('initialize_word_embeddings() was called but '
                            'share_word_embeddings is false')

        # This function just initializes the word embeddings in the final stage
        # when we are using pipeline parallelism. If we aren't using pipeline
        # parallelism there is nothing to do.
        if args.pipeline_model_parallel_size == 1:
            return

        # Parameters are shared between the word embeddings layer, and the
        # heads at the end of the model. In a pipelined setup with more than
        # one stage, the initial embedding layer and the head are on different
        # workers, so we do the following:
        # 1. Create a second copy of word_embeddings on the last stage, with
        #    initial parameters of 0.0.
        # 2. Do an all-reduce between the first and last stage to ensure that
        #    the two copies of word_embeddings start off with the same
        #    parameter values.
        # 3. In the training loop, do an all-reduce between the grads of the
        #    two word_embeddings layers to ensure that every applied weight
        #    update is the same on both stages.
        if mpu.is_pipeline_last_stage():
            assert not mpu.is_pipeline_first_stage()
            self._word_embeddings_for_head_key = 'word_embeddings_for_head'
            # set word_embeddings weights to 0 here, then copy first
            # stage's weights using all_reduce below.
            self.word_embeddings = mpu.VocabParallelEmbedding(
                args.padded_vocab_size,
                args.hidden_size,
                init_method=init_method_normal(args.init_method_std))
            self.word_embeddings.weight.data.fill_(0)
            self.word_embeddings.weight.shared = True

        # Ensure that first and last stages have the same initial parameter
        # values.
        if torch.distributed.is_initialized():
            if mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage():
                torch.distributed.all_reduce(
                    self.word_embeddings_weight().data,
                    group=mpu.get_embedding_group())
        else:
            print("WARNING! Distributed processes aren't initialized, so "
                  "word embeddings in the last layer are not initialized. "
                  "If you are just manipulating a model this is fine, but "
                  "this needs to be handled manually. If you are training "
                  "something is definitely wrong.")
Example #8
    def load_state_dict(self, state_dict, strict=True):
        """Customized load."""

        self.language_model.load_state_dict(
            state_dict[self._language_model_key], strict=strict)
        if mpu.is_pipeline_last_stage():
            self.lm_head.load_state_dict(
                state_dict[self._lm_head_key], strict=strict)
        if mpu.is_pipeline_last_stage() and self.add_binary_head:
            self.binary_head.load_state_dict(
                state_dict[self._binary_head_key], strict=strict)
        # Load word_embeddings.
        if mpu.is_pipeline_last_stage() and not mpu.is_pipeline_first_stage():
            self.word_embeddings.load_state_dict(
                state_dict[self._word_embeddings_for_head_key], strict=strict)
Example #9
    def model_provider():
        """Build the model."""

        if eval_metric == 'loss':
            parallel_output = True
        elif eval_metric == 'accuracy':
            parallel_output = False
        else:
            raise NotImplementedError('output type for {} evaluation metric '
                                      'is not supported.'.format(eval_metric))

        print_rank_0('building GPT2 model ...')
        if mpu.get_pipeline_model_parallel_world_size() > 1:
            # Determine model based on position of stage in pipeline.
            if mpu.is_pipeline_first_stage():
                model = GPT2ModelFirstStage(num_tokentypes=0)
            elif mpu.is_pipeline_last_stage():
                model = GPT2ModelLastStage(parallel_output=parallel_output,
                                           num_tokentypes=0)
            else:
                model = GPT2ModelIntermediateStage(num_tokentypes=0)
        else:
            model = GPT2Model(num_tokentypes=0,
                              parallel_output=parallel_output)

        return model
Example #10
    def __init__(self, attention_mask_func,
                 init_method, output_layer_init_method):
        super(ParallelTransformer, self).__init__()
        args = get_args()

        self.fp32_residual_connection = args.fp32_residual_connection

        # Store activation checkpointing flag.
        self.checkpoint_activations = args.checkpoint_activations
        self.checkpoint_num_layers = args.checkpoint_num_layers

        # Number of layers.
        assert args.num_layers % mpu.get_pipeline_model_parallel_world_size() == 0, \
            'num_layers must be divisible by pipeline_model_parallel_size'
        self.num_layers = args.num_layers // mpu.get_pipeline_model_parallel_world_size()

        # Transformer layers.
        def build_layer(layer_number):
            return ParallelTransformerLayer(
                attention_mask_func, init_method,
                output_layer_init_method, layer_number)
        offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers
        self.layers = torch.nn.ModuleList(
            [build_layer(i + 1 + offset) for i in range(self.num_layers)])

        if mpu.is_pipeline_last_stage():
            # Final layer norm before output.
            LayerNorm = import_layernorm(args.fp32_residual_connection)
            self.final_layernorm = LayerNorm(
                args.hidden_size,
                eps=args.layernorm_epsilon)
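
The offset arithmetic above determines which global layer numbers a given pipeline stage owns. A small illustrative calculation (values chosen for the example, not taken from the snippets):

num_layers_total = 24          # args.num_layers
pipeline_size = 4              # pipeline_model_parallel_world_size
layers_per_stage = num_layers_total // pipeline_size   # 6

for rank in range(pipeline_size):
    offset = rank * layers_per_stage
    layer_numbers = [i + 1 + offset for i in range(layers_per_stage)]
    print(rank, layer_numbers)
# rank 0 -> [1, 2, 3, 4, 5, 6], rank 1 -> [7, 8, 9, 10, 11, 12], ...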
Example #11
    def forward(self, model_input, attention_mask, tokentype_ids=None):

        extended_attention_mask = bert_extended_attention_mask(attention_mask)

        kwargs = {}
        if mpu.is_pipeline_first_stage():
            input_ids = model_input
            position_ids = bert_position_ids(input_ids)

            args = [input_ids, position_ids, extended_attention_mask]
            kwargs['tokentype_ids'] = tokentype_ids
        else:
            args = [model_input, extended_attention_mask]
        lm_output = self.language_model(*args, **kwargs)
        if mpu.is_pipeline_last_stage():
            _, pooled_output = lm_output
            classification_output = self.classification_dropout(pooled_output)
            classification_logits = self.classification_head(
                classification_output)

            # Reshape back to separate choices.
            classification_logits = classification_logits.view(
                -1, self.num_classes)

            return classification_logits
        return lm_output
Example #12
def broadcast_from_last_to_first_pipeline_stage(size, dtype, tensor=None):
    """Broadcast tensor values from last stage into the first stage."""

    is_last_stage = mpu.is_pipeline_last_stage()
    is_first_stage = mpu.is_pipeline_first_stage()
    # If the first and last stages are the same, then there is no
    # pipeline parallelism and no need to communicate.
    if is_first_stage and is_last_stage:
        return tensor
    # Only the first and last pipeline stages need to be involved.
    if is_last_stage or is_first_stage:
        if is_last_stage:
            _is_cuda_contiguous(tensor)
        else:
            tensor = torch.empty(size,
                                 dtype=dtype,
                                 device=torch.cuda.current_device())
        src = mpu.get_pipeline_model_parallel_last_rank()
        group = mpu.get_embedding_group()
        # Broadcast from last stage into the first stage.
        torch.distributed.broadcast(tensor, src, group)
    else:
        tensor = None

    return tensor
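
The checks _is_cuda_contiguous (used above) and _is_cuda (used in example #15 below) are not shown in these snippets; a plausible minimal sketch, offered as an assumption rather than the actual implementation:

import torch

def _is_cuda(tensor):
    """Assert the argument is a CUDA tensor."""
    assert torch.is_tensor(tensor) and tensor.is_cuda

def _is_cuda_contiguous(tensor):
    """Assert the argument is a contiguous CUDA tensor."""
    _is_cuda(tensor)
    assert tensor.is_contiguous()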
Example #13
def _cross_entropy_forward_step(batch, model, input_tensor):
    """Simple forward step with cross-entropy loss."""
    timers = get_timers()

    # Get the batch.
    timers('batch-generator').start()
    try:
        batch_ = next(batch)
    except BaseException:
        batch_ = batch
    tokens, types, labels, attention_mask = process_batch(batch_)
    timers('batch-generator').stop()

    # Forward model.
    if mpu.is_pipeline_first_stage():
        assert input_tensor is None
        output_tensor = model(tokens, attention_mask, tokentype_ids=types)
    else:
        assert input_tensor is not None
        output_tensor = model(input_tensor, attention_mask)

    if mpu.is_pipeline_last_stage():
        logits = output_tensor

        # Cross-entropy loss.
        loss_func = torch.nn.CrossEntropyLoss()
        loss = loss_func(logits.contiguous().float(), labels)

        # Reduce loss for logging.
        averaged_loss = average_losses_across_data_parallel_group([loss])

        return loss, {'lm loss': averaged_loss[0]}
    return output_tensor
Example #14
    def forward_step_with_balance_loss(data_iterator, model, input_tensor):
        args = get_args()
        output = forward_step_func(data_iterator, model, input_tensor)

        if is_pipeline_last_stage():
            loss_name = args.balance_strategy + "_loss"

            (loss, state_dict), bal_loss = (
                output,
                (torch.tensor(
                    balance_dict[loss_name],
                    device=balance_dict[loss_name][0].device,
                ).mean() * args.balance_loss_weight).float(),
            )

            # average across the world group
            world_group = get_torch_default_comm()
            world_size = torch.distributed.get_world_size(group=world_group)
            averaged_bal_loss = bal_loss.clone().detach()
            torch.distributed.all_reduce(averaged_bal_loss, group=world_group)
            averaged_bal_loss /= world_size

            loss += bal_loss
            state_dict[loss_name] = averaged_bal_loss

            return loss, state_dict
        else:
            return output
Example #15
def copy_from_last_to_first_pipeline_stage(size, dtype, tensor=None):
    """Copy tensor values from last stage into the first stage.
    Note that the input tensor is updated in place."""

    is_last_stage = mpu.is_pipeline_last_stage()
    is_first_stage = mpu.is_pipeline_first_stage()
    # If the first and last stages are the same, then there is no
    # pipeline parallelism and no need to communicate.
    if is_first_stage and is_last_stage:
        return
    # Only the first and last pipeline stages need to be involved.
    if is_last_stage or is_first_stage:
        _is_cuda(tensor)
        is_contiguous = tensor.is_contiguous()
        src = mpu.get_pipeline_model_parallel_last_rank()
        group = mpu.get_embedding_group()
        if is_contiguous:
            tensor_ = tensor
        else:
            if is_last_stage:
                tensor_ = tensor.contiguous()
            else:
                tensor_ = torch.empty(size,
                                      dtype=dtype,
                                      device=torch.cuda.current_device())
        # Broadcast from last stage into the first stage.
        torch.distributed.broadcast(tensor_, src, group)
        # Update the first stage tensor
        if is_first_stage and not is_contiguous:
            tensor[...] = tensor_
Example #16
    def load_state_dict(self, state_dict, strict=True):
        """Customized load."""

        # Embedding.
        if mpu.is_pipeline_first_stage():
            if self._embedding_key in state_dict:
                state_dict_ = state_dict[self._embedding_key]
            else:
                # for backward compatibility.
                state_dict_ = {}
                for key in state_dict.keys():
                    if '_embeddings' in key:
                        state_dict_[key] = state_dict[key]
            self.embedding.load_state_dict(state_dict_, strict=strict)

        # Transformer.
        if self._transformer_key in state_dict:
            state_dict_ = state_dict[self._transformer_key]
        else:
            # for backward compatibility.
            state_dict_ = {}
            for key in state_dict.keys():
                if 'transformer.' in key:
                    state_dict_[key.split('transformer.')[1]] = state_dict[key]
        self.transformer.load_state_dict(state_dict_, strict=strict)

        # Pooler.
        if mpu.is_pipeline_last_stage() and self.add_pooler:
            assert 'pooler' in state_dict, \
                'could not find data for pooler in the checkpoint'
            self.pooler.load_state_dict(state_dict[self._pooler_key],
                                        strict=strict)
Example #17
    def __init__(self,
                 attention_mask_func,
                 init_method,
                 output_layer_init_method,
                 num_tokentypes=0,
                 add_pooler=False):
        super(TransformerLanguageModelBase, self).__init__()
        args = get_args()

        self.hidden_size = args.hidden_size
        self.num_tokentypes = num_tokentypes
        self.init_method = init_method
        self.add_pooler = add_pooler

        # Embeddings.
        if mpu.is_pipeline_first_stage():
            self.embedding = Embedding(self.hidden_size,
                                       args.padded_vocab_size,
                                       args.max_position_embeddings,
                                       args.hidden_dropout, self.init_method,
                                       self.num_tokentypes)
            self._embedding_key = 'embedding'

        # Transformer.
        self.transformer = ParallelTransformer(attention_mask_func,
                                               self.init_method,
                                               output_layer_init_method)
        self._transformer_key = 'transformer'

        # Pooler.
        if mpu.is_pipeline_last_stage() and self.add_pooler:
            self.pooler = Pooler(self.hidden_size, self.init_method)
            self._pooler_key = 'pooler'
Example #18
    def forward(self,
                language_model_input,
                attention_mask,
                tokentype_ids=None,
                layer_past=None,
                get_key_value=False,
                pooling_sequence_index=0):

        # Embeddings.
        if mpu.is_pipeline_first_stage():
            (input_ids, position_ids) = language_model_input
            embedding_output = self.embedding(input_ids,
                                              position_ids,
                                              tokentype_ids=tokentype_ids)
            transformer_input = embedding_output
        else:
            transformer_input = language_model_input

        # Transformer.
        transformer_output = self.transformer(transformer_input,
                                              attention_mask,
                                              layer_past=layer_past,
                                              get_key_value=get_key_value)

        if mpu.is_pipeline_last_stage() and self.add_pooler:
            pooled_output = self.pooler(transformer_output,
                                        pooling_sequence_index)
            return transformer_output, pooled_output

        return transformer_output
Example #19
 def forward(self, *inputs, **kwargs):
     if mpu.is_pipeline_first_stage():
         inputs = fp32_to_fp16(inputs)
     outputs = self.module(*inputs, **kwargs)
     if mpu.is_pipeline_last_stage():
         outputs = fp16_to_fp32(outputs)
     return outputs
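
fp32_to_fp16 and fp16_to_fp32 are referenced above but not shown. A minimal sketch of the conversion they imply, under the assumption that they simply walk nested lists/tuples and convert floating-point tensors:

import torch

def _convert_floats(inputs, from_dtype, to_dtype):
    # Recursively convert matching floating-point tensors inside nested
    # lists/tuples; leave everything else untouched.
    if isinstance(inputs, (list, tuple)):
        return type(inputs)(_convert_floats(x, from_dtype, to_dtype) for x in inputs)
    if isinstance(inputs, torch.Tensor) and inputs.dtype == from_dtype:
        return inputs.to(to_dtype)
    return inputs

def fp32_to_fp16(inputs):
    return _convert_floats(inputs, torch.float32, torch.float16)

def fp16_to_fp32(inputs):
    return _convert_floats(inputs, torch.float16, torch.float32)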
Example #20
def backward_step_with_communication(optimizer, model, input_tensors,
                                     output_tensors, timers):
    input_tensor = input_tensors.pop(0)
    output_tensor = output_tensors.pop(0)

    if mpu.is_pipeline_last_stage():
        output_tensor_grad = None
    else:
        timers('backward-recv').start()
        _, output_tensor_grad = communicate(tensor_send_next=None,
                                            tensor_send_prev=None,
                                            recv_forward=False,
                                            recv_backward=True)
        timers('backward-recv').stop()

    # Backward pass for one step.
    timers('backward-compute').start()
    input_grad_tensor = \
        backward_step(optimizer, model, input_tensor, output_tensor, output_tensor_grad)
    timers('backward-compute').stop()

    if not mpu.is_pipeline_first_stage():
        timers('backward-send').start()
        communicate(tensor_send_next=None,
                    tensor_send_prev=input_grad_tensor,
                    recv_forward=False,
                    recv_backward=False)
        timers('backward-send').stop()
Example #21
    def __init__(self, num_tokentypes=2, add_binary_head=True,
                 parallel_output=True):
        super(BertModelBase, self).__init__()
        args = get_args()

        self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
        self.add_binary_head = add_binary_head
        self.parallel_output = parallel_output

        init_method = init_method_normal(args.init_method_std)
        scaled_init_method = scaled_init_method_normal(args.init_method_std,
                                                       args.num_layers)

        self.language_model, self._language_model_key = get_language_model(
            attention_mask_func=bert_attention_mask_func,
            num_tokentypes=num_tokentypes,
            add_pooler=self.add_binary_head,
            init_method=init_method,
            scaled_init_method=scaled_init_method)

        self.initialize_word_embeddings(init_method_normal)
        if mpu.is_pipeline_last_stage():
            self.lm_head = BertLMHead(
                self.word_embeddings_weight().size(0),
                args.hidden_size, init_method, args.layernorm_epsilon, parallel_output)
            self._lm_head_key = 'lm_head'
            self.binary_head = None
            if self.add_binary_head:
                self.binary_head = get_linear_layer(args.hidden_size, 2,
                                                    init_method)
                self._binary_head_key = 'binary_head'
Example #22
def forward_step_with_communication(forward_step_func, data_iterator, model,
                                    input_tensors, output_tensors,
                                    losses_reduced, timers):
    args = get_args()

    if not mpu.is_pipeline_first_stage():
        timers('forward-recv').start()
        input_tensor, _ = communicate(tensor_send_next=None,
                                      tensor_send_prev=None,
                                      recv_forward=True,
                                      recv_backward=False)
        timers('forward-recv').stop()
    else:
        input_tensor = None

    # Forward model for one step.
    timers('forward-compute').start()
    output_tensor = forward_step_func(data_iterator, model, input_tensor)
    timers('forward-compute').stop()

    if mpu.is_pipeline_last_stage():
        loss, loss_reduced = output_tensor
        output_tensor = loss / get_num_microbatches()
        losses_reduced.append(loss_reduced)
    else:
        timers('forward-send').start()
        communicate(tensor_send_next=output_tensor,
                    tensor_send_prev=None,
                    recv_forward=False,
                    recv_backward=False)
        timers('forward-send').stop()

    input_tensors.append(input_tensor)
    output_tensors.append(output_tensor)
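
For context, a hypothetical driver (not from the source) showing how example #22's forward helper and example #20's backward helper could be paired for one batch: run one forward per microbatch, then the matching backwards in the same order, with input_tensors/output_tensors acting as FIFO queues.

def run_forward_backward(forward_step_func, data_iterator, model, optimizer, timers):
    input_tensors, output_tensors, losses_reduced = [], [], []
    # All forward passes first...
    for _ in range(get_num_microbatches()):
        forward_step_with_communication(forward_step_func, data_iterator, model,
                                        input_tensors, output_tensors,
                                        losses_reduced, timers)
    # ...then the matching backward passes, popped in FIFO order.
    for _ in range(get_num_microbatches()):
        backward_step_with_communication(optimizer, model, input_tensors,
                                         output_tensors, timers)
    return losses_reduced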
Example #23
def generate_and_write_samples_unconditional(model):

    args = get_args()
    assert args.genfile is not None
    with open(args.genfile, 'w') as f:
        for datum in generate_samples_unconditional(model):
            if mpu.is_pipeline_last_stage() and \
               mpu.get_tensor_model_parallel_rank() == 0:
                f.write(json.dumps(datum) + '\n')
Example #24
 def word_embeddings_weight(self):
     if mpu.is_pipeline_first_stage():
         return self.language_model.embedding.word_embeddings.weight
     if mpu.is_pipeline_last_stage():
         if not self.share_word_embeddings:
             raise Exception('word_embeddings_weight() called for last '
                             'stage, but share_word_embeddings is false')
         return self.word_embeddings.weight
     raise Exception('word_embeddings_weight() should be '
                     'called for first and last stage only')
Example #25
def send_forward(output_tensor, timers=None):
    """Send tensor to next rank in pipeline (forward send)."""
    if not mpu.is_pipeline_last_stage():
        if timers is not None:
            timers('forward-send').start()
        _communicate(tensor_send_next=output_tensor,
                     tensor_send_prev=None,
                     recv_prev=False,
                     recv_next=False)
        if timers is not None:
            timers('forward-send').stop()
Example #26
def evaluate(forward_step_func, data_iterator, model, verbose=False):
    """Evaluation."""
    args = get_args()

    # Turn on evaluation mode which disables dropout.
    model.eval()

    total_loss_dict = {}

    with torch.no_grad():
        iteration = 0
        while iteration < args.eval_iters:
            iteration += 1
            if verbose and iteration % args.log_interval == 0:
                print_rank_0('Evaluating iter {}/{}'.format(iteration,
                                                            args.eval_iters))

            for _ in range(get_num_microbatches()):
                if not mpu.is_pipeline_first_stage():
                    input_tensor, _ = communicate(
                        tensor_send_next=None,
                        tensor_send_prev=None,
                        recv_forward=True,
                        recv_backward=False)
                else:
                    input_tensor = None

                # Forward evaluation.
                output_tensor = forward_step_func(data_iterator, model, input_tensor)

                if mpu.is_pipeline_last_stage():
                    _, loss_dict = output_tensor
                    # Reduce across processes.
                    for key in loss_dict:
                        total_loss_dict[key] = total_loss_dict.get(key, torch.cuda.FloatTensor([0.0])) + \
                            loss_dict[key]
                else:
                    communicate(
                        tensor_send_next=output_tensor,
                        tensor_send_prev=None,
                        recv_forward=False,
                        recv_backward=False)

            args.consumed_valid_samples += mpu.get_data_parallel_world_size() \
                                           * args.micro_batch_size \
                                           * get_num_microbatches()
    # Move model back to the train mode.
    model.train()

    for key in total_loss_dict:
        total_loss_dict[key] /= args.eval_iters * get_num_microbatches()

    return total_loss_dict
Example #27
def send_to_next_pipeline_rank(tensor=None):
    """Send output to the next pipeline stage."""
    if not mpu.is_pipeline_last_stage():
        assert tensor is not None
        send_next_op = torch.distributed.P2POp(
            torch.distributed.isend, tensor,
            mpu.get_pipeline_model_parallel_next_rank())
        reqs = torch.distributed.batch_isend_irecv([send_next_op])
        for req in reqs:
            req.wait()
        # To protect against race condition when using batch_isend_irecv().
        torch.cuda.synchronize()
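
By analogy with the send above, a sketched receive for the previous stage's output (hedged: the codebase's actual counterpart may differ, and get_pipeline_model_parallel_prev_rank is assumed to exist alongside the next-rank helper used above). The buffer must be preallocated on the GPU, e.g. with a helper like the one sketched after example #4.

import torch

def recv_from_prev_pipeline_rank(recv_buffer):
    """Receive the previous stage's output into a preallocated CUDA buffer."""
    if not mpu.is_pipeline_first_stage():
        assert recv_buffer is not None
        recv_prev_op = torch.distributed.P2POp(
            torch.distributed.irecv, recv_buffer,
            mpu.get_pipeline_model_parallel_prev_rank())
        reqs = torch.distributed.batch_isend_irecv([recv_prev_op])
        for req in reqs:
            req.wait()
        # To protect against race condition when using batch_isend_irecv().
        torch.cuda.synchronize()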
Example #28
    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                       keep_vars=False):
        """For easy load when model is combined with other heads,
        add an extra key."""

        state_dict_ = {}
        state_dict_[self._language_model_key] \
            = self.language_model.state_dict_for_save_checkpoint(
            destination, prefix, keep_vars)
        if mpu.is_pipeline_last_stage():
            state_dict_[self._lm_head_key] \
                = self.lm_head.state_dict_for_save_checkpoint(
                destination, prefix, keep_vars)
        if mpu.is_pipeline_last_stage() and self.add_binary_head:
            state_dict_[self._binary_head_key] \
                = self.binary_head.state_dict(destination, prefix, keep_vars)
        # Save word_embeddings.
        if mpu.is_pipeline_last_stage() and not mpu.is_pipeline_first_stage():
            state_dict_[self._word_embeddings_for_head_key] \
                = self.word_embeddings.state_dict(destination, prefix, keep_vars)
        return state_dict_
Example #29
    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                       keep_vars=False):

        state_dict_ = {}
        state_dict_[self._language_model_key] \
            = self.language_model.state_dict_for_save_checkpoint(
                destination, prefix, keep_vars)
        # Save word_embeddings.
        if mpu.is_pipeline_last_stage() and not mpu.is_pipeline_first_stage():
            state_dict_[self._word_embeddings_for_head_key] \
                = self.word_embeddings.state_dict(destination, prefix, keep_vars)
        return state_dict_
Example #30
    def forward(self, hidden_states, attention_mask, layer_past=None,
                get_key_value=False):

        # Checks.
        if layer_past is not None:
            assert get_key_value, \
                'for not None values in layer_past, ' \
                'expected get_key_value to be set'
        if get_key_value:
            assert not self.checkpoint_activations, \
                'get_key_value does not work with ' \
                'activation checkpointing'

        if mpu.is_pipeline_first_stage():
            # Data format change to avoid explicit transposes: [b s h] --> [s b h].
            # If the input flag for fp32 residual connection is set, convert to float.
            if self.fp32_residual_connection:
                hidden_states = hidden_states.transpose(0, 1).contiguous().float()
            # Otherwise, leave it as is.
            else:
                hidden_states = hidden_states.transpose(0, 1).contiguous()

        if self.checkpoint_activations:
            hidden_states = self._checkpointed_forward(hidden_states,
                                                       attention_mask)
        else:
            if get_key_value:
                presents = []
            for index in range(self.num_layers):
                layer = self._get_layer(index)
                past = None
                if layer_past is not None:
                    past = layer_past[index]
                hidden_states = layer(hidden_states,
                                      attention_mask,
                                      layer_past=past,
                                      get_key_value=get_key_value)
                if get_key_value:
                    hidden_states, present = hidden_states
                    presents.append(present)
        
        # Final layer norm.
        if mpu.is_pipeline_last_stage():
            # Reverting data format change [s b h] --> [b s h].
            hidden_states = hidden_states.transpose(0, 1).contiguous()
            output = self.final_layernorm(hidden_states)
        else:
            output = hidden_states
        if get_key_value:
            output = [output, presents]

        return output