Пример #1
0
def generate_samples(model, tokenizer, args):
    """Debug variant of the sample-generation loop that uses a hard-coded prompt.

    Each round, model-parallel rank 0 takes the prompt (currently the fixed
    string ``"localStorage.getItem("`` instead of ``input()``). It decides
    whether to stop, skip, or generate, and broadcasts that decision over the
    model-parallel group. All ranks therefore issue the same sequence of
    collective calls. Debug messages are printed along the way.

    NOTE(review): with the hard-coded prompt, "stop" is never seen, so this loop
    never returns. Restore the ``input()`` call to get an exit path.

    Args:
        model: Passed to ``generate``; switched to eval mode first.
        tokenizer: Callable returning a dict with an ``'input_ids'`` list.
        args: Provides ``seq_length``, ``out_seq_length``, ``temperature``,
            ``top_k`` and ``top_p``.
    """
    print(f"generate_samples was called with model {model} \n and tokenizer {tokenizer}")
    model.eval()
    with torch.no_grad():
        while True:
            torch.distributed.barrier(group=mpu.get_model_parallel_group())
            # Decision broadcast from the source rank below:
            # 0 = generate, 1 = stop, 2 = skip this round (prompt rejected).
            terminate_runs = 0
            print(f"terminate_runs = {terminate_runs}")

            if mpu.get_model_parallel_rank() == 0:
                print("get_model_parallel_rank() was 0")
                # raw_text = input("\nContext prompt (stop to exit) >>> ")
                raw_text = "localStorage.getItem("
                while not raw_text:
                    print('Prompt should not be empty!')
                    raw_text = input("\nContext prompt (stop to exit) >>> ")

                if "stop" in raw_text:
                    terminate_runs = 1
                else:
                    context_length = len(tokenizer(raw_text)['input_ids'])
                    if context_length >= args.seq_length // 2:
                        print("\nContext length", context_length,
                              "\nPlease give smaller context (half of the sequence length)!")
                        # Do not `continue` here: the other ranks are already
                        # heading into the broadcast below, and skipping it on
                        # this rank alone would mismatch the collectives.
                        terminate_runs = 2
            else:
                print(f"get_model_parallel_rank() was NOT 0 but {mpu.get_model_parallel_rank()}")
                # Placeholder prompt for generate(). It is tokenized and padded
                # up to seq_length there, then overwritten by the tokens
                # broadcast from the source rank. Without this assignment,
                # raw_text would be unbound on this rank.
                raw_text = "EMPTY TEXT"

            terminate_runs_tensor = torch.cuda.LongTensor([terminate_runs])
            torch.distributed.broadcast(terminate_runs_tensor, mpu.get_model_parallel_src_rank(),
                                        group=mpu.get_model_parallel_group())
            terminate_runs = terminate_runs_tensor[0].item()

            if terminate_runs == 1:
                return
            if terminate_runs == 2:
                # Every rank returns to the barrier at the top together.
                continue

            start_time = time.time()
            print("generating...")
            generated = generate(
                model, tokenizer, raw_text,
                out_seq_length=args.out_seq_length,
                seq_length=args.seq_length,
                temperature=args.temperature,
                top_k=args.top_k,
                top_p=args.top_p,
            )

            if mpu.get_model_parallel_rank() == 0:
                print("We should clear the terminal and print results...")
                os.system('clear')
                print("\nTime taken: {:.2f}\n".format(time.time() - start_time), flush=True)
                print("\nContext:", raw_text, flush=True)
                print("\nGPT:", generated, flush=True)

            torch.distributed.barrier(group=mpu.get_model_parallel_group())
def generate_samples(model, tokenizer, args):
    """Interactively read prompts on model-parallel rank 0 and print generated text.

    Loops until a prompt containing "stop" is entered. Every rank of the
    model-parallel group must call this together. Each round, the decision to
    generate, stop, or skip is broadcast from the model-parallel source rank
    (assumed to be the rank-0 process that reads the prompt). All ranks
    therefore issue the same sequence of collective calls.

    Args:
        model: Passed to ``generate``; switched to eval mode first.
        tokenizer: Callable returning a dict with an ``'input_ids'`` list.
        args: Provides ``seq_length``, ``out_seq_length``, ``temperature``,
            ``top_k`` and ``top_p``.
    """
    model.eval()
    with torch.no_grad():
        while True:
            torch.distributed.barrier(group=mpu.get_model_parallel_group())
            # Decision broadcast from the source rank below:
            # 0 = generate, 1 = stop, 2 = skip this round (prompt rejected).
            terminate_runs = 0

            if mpu.get_model_parallel_rank() == 0:
                raw_text = input("\nContext prompt (stop to exit) >>> ")
                while not raw_text:
                    print('Prompt should not be empty!')
                    raw_text = input("\nContext prompt (stop to exit) >>> ")

                if "stop" in raw_text:
                    terminate_runs = 1
                else:
                    context_length = len(tokenizer(raw_text)['input_ids'])
                    if context_length >= args.seq_length // 2:
                        print(
                            "\nContext length", context_length,
                            "\nPlease give smaller context (half of the sequence length)!"
                        )
                        # Do not `continue` here: the other ranks are already
                        # heading into the broadcast below, and skipping it on
                        # this rank alone would mismatch the collectives.
                        terminate_runs = 2
            else:
                # Placeholder prompt for generate(). It is tokenized and padded
                # up to seq_length there, then overwritten by the tokens
                # broadcast from the source rank. Without this assignment,
                # raw_text would be unbound on this rank.
                raw_text = "EMPTY TEXT"

            terminate_runs_tensor = torch.cuda.LongTensor([terminate_runs])
            torch.distributed.broadcast(terminate_runs_tensor,
                                        mpu.get_model_parallel_src_rank(),
                                        group=mpu.get_model_parallel_group())
            terminate_runs = terminate_runs_tensor[0].item()

            if terminate_runs == 1:
                return
            if terminate_runs == 2:
                # Every rank returns to the barrier at the top together.
                continue

            start_time = time.time()
            generated = generate(model,
                                 tokenizer,
                                 raw_text,
                                 out_seq_length=args.out_seq_length,
                                 seq_length=args.seq_length,
                                 temperature=args.temperature,
                                 top_k=args.top_k,
                                 top_p=args.top_p)

            if mpu.get_model_parallel_rank() == 0:
                os.system('clear')
                print("\nTaken time {:.2f}\n".format(time.time() - start_time),
                      flush=True)
                print("\nContext:", raw_text, flush=True)
                print("\nGPT:", generated, flush=True)

            torch.distributed.barrier(group=mpu.get_model_parallel_group())
Пример #3
0
def get_train_val_test_data(args):
    """Load the data on rank zero and broadcast vocab info to all GPUs.

    Only model-parallel rank 0 builds the dataloaders and tokenizer. It pads the
    vocabulary size up to a multiple of
    ``args.make_vocab_size_divisible_by * model_parallel_world_size``. The padded
    size, the end-of-document token id and the do_train/do_valid/do_test flags
    are then broadcast to the rest of the model-parallel group.

    Args:
        args: Namespace read for ``make_vocab_size_divisible_by`` and the
            ``do_train``/``do_valid``/``do_test`` flags. The flags are
            overwritten on every rank with the broadcast values (as ints).

    Returns:
        ``(train_data, val_data, test_data, num_tokens, eod_token, tokenizer)``.
        The three data objects and ``tokenizer`` are ``None`` on ranks other than
        model-parallel rank 0. ``num_tokens`` is the padded vocabulary size.
    """
    train_data, val_data, test_data = None, None, None

    # Data loader only on rank 0 of each model parallel group.
    if mpu.get_model_parallel_rank() == 0:
        (train_data, val_data, test_data), num_tokens, eod_token, tokenizer = make_gpt3_dataloaders(args)
        before = num_tokens
        multiple = args.make_vocab_size_divisible_by * mpu.get_model_parallel_world_size()
        # Round up to the next multiple with ceil division, rather than
        # incrementing one token at a time.
        after = ((before + multiple - 1) // multiple) * multiple
        print_rank_0(
            '> padded vocab (size: {}) with {} dummy tokens (new size: {})'.format(before, after - before, after))
        print_rank_0('> end-of-document token: {}'.format(eod_token))
        token_counts = torch.cuda.LongTensor(
            [after, eod_token, int(args.do_train), int(args.do_valid), int(args.do_test)])
    else:
        tokenizer = None
        # Same-shaped placeholder; filled in by the broadcast below.
        token_counts = torch.cuda.LongTensor([0, 0, 0, 0, 0])

    # Broadcast padded vocab size, EOD id and phase flags from the source rank.
    torch.distributed.broadcast(token_counts,
                                mpu.get_model_parallel_src_rank(),
                                group=mpu.get_model_parallel_group())
    num_tokens = token_counts[0].item()
    eod_token = token_counts[1].item()
    args.do_train = token_counts[2].item()
    args.do_valid = token_counts[3].item()
    args.do_test = token_counts[4].item()

    return train_data, val_data, test_data, num_tokens, eod_token, tokenizer
Пример #4
0
def generate(model, tokenizer, raw_text, out_seq_length=256, seq_length=512, temperature=1.0, top_k=0, top_p=0.9):
    """Sample a continuation of ``raw_text`` and return the decoded text.

    The context length and token buffer are broadcast from the model-parallel
    source rank, so all ranks of the group must call this together. The buffer
    broadcast needs equal-length buffers on every rank. That holds when every
    rank's context is at most ``seq_length`` tokens, because shorter contexts
    are padded up to ``seq_length`` with the ``'<pad>'`` id.

    Args:
        model: Called as ``model(tokens, position_ids, attention_mask)``. It
            returns 3-D logits, presumably (batch, position, vocab); confirm.
        tokenizer: Callable returning ``{'input_ids': [...]}``. It also exposes
            ``encoder`` (must contain ``'<pad>'``) and ``decode``.
        raw_text: Prompt text.
        out_seq_length: Generation budget. The loop runs for at most
            ``context_length + out_seq_length`` steps.
        seq_length: Buffer size. Generation stops once the buffer is full.
        temperature: Divides the logits before filtering.
        top_k, top_p: Passed to ``top_k_logits``.

    Returns:
        The whole buffer decoded (prompt, continuation and any trailing pad
        tokens), cut before the first ``"<|endoftext|>"`` if one is present.
    """
    eos_text = "<|endoftext|>"
    context_tokens = tokenizer(raw_text)['input_ids']
    context_length = len(context_tokens)
    pad_id = tokenizer.encoder['<pad>']
    if context_length < seq_length:
        context_tokens.extend([pad_id] * (seq_length - context_length))

    context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
    context_length_tensor = torch.cuda.LongTensor([context_length])

    torch.distributed.broadcast(context_length_tensor, mpu.get_model_parallel_src_rank(),
                                group=mpu.get_model_parallel_group())
    torch.distributed.broadcast(context_tokens_tensor, mpu.get_model_parallel_src_rank(),
                                group=mpu.get_model_parallel_group())

    context_length = context_length_tensor[0].item()

    tokens = context_tokens_tensor.view(1, -1).contiguous()
    tokens = tokens.to(torch.cuda.current_device())
    attention_mask, _loss_mask, position_ids = get_masks_and_position_ids(tokens, pad_id, False, False)

    counter = 0
    start_context_length = context_length

    while counter < (start_context_length + out_seq_length):
        # Check before writing: a context that already fills the buffer would
        # otherwise index past the end of `tokens`.
        if context_length >= seq_length:
            break
        logits = model(tokens, position_ids, attention_mask)
        logits = logits[:, context_length - 1, :] / temperature
        logits = top_k_logits(logits, top_k=top_k, top_p=top_p)
        probs = torch.nn.functional.softmax(logits, dim=-1)
        prev = torch.multinomial(probs, num_samples=1)
        tokens[0, context_length] = prev[0]
        context_length += 1
        counter += 1

        # Search the untruncated text for the end marker. Slicing at the
        # marker first would remove it, so the check could never fire. Only
        # the filled prefix is decoded, which keeps padding out of the check.
        text_so_far = tokenizer.decode(tokens[0, :context_length].tolist())
        if eos_text in text_so_far:
            break

    output_text = tokenizer.decode(tokens.view(-1).tolist())
    # partition() keeps the whole string when the marker is absent. Slicing
    # with find() would drop the last character, because find() returns -1.
    return output_text.partition(eos_text)[0]
Пример #5
0
 def has_overflow(self, params):
     """Report whether any model-parallel rank saw an overflow in ``params``.

     Each model-parallel GPU holds only part of the model, so the local result
     of ``has_overflow_serial`` is combined with a MAX all-reduce over the
     model-parallel group. Every rank gets True if any rank's flag was set.
     """
     flag_buffer = torch.cuda.ByteTensor([self.has_overflow_serial(params)])
     torch.distributed.all_reduce(
         flag_buffer,
         op=torch.distributed.ReduceOp.MAX,
         group=mpu.get_model_parallel_group(),
     )
     return bool(flag_buffer[0].item())
Пример #6
0
    def __call__(self, text=None, input_ids=None, labels=None, **kwargs):
        """Run the model row by row and optionally compute a per-row LM loss.

        Args:
            text: Prompt tokenized with ``self.tokenizer`` when ``input_ids``
                is None (None is treated as an empty string).
            input_ids: Rows of token ids (tensor, or list of lists that is
                converted to a CUDA tensor). Each row is padded/truncated,
                broadcast over the model-parallel group and run through
                ``self.model`` separately.
            labels: Optional rows of label ids aligned with ``input_ids``.
                Positions equal to ``self.pad_token_id`` are ignored by the loss.
            **kwargs: Accepted for call compatibility and ignored.

        Returns:
            ``ModelOutput(logits, loss)``. ``logits`` holds the per-row logits
            concatenated on dim 0 and cut to the last row's original length.
            ``loss`` is a list with one loss per row when ``labels`` is given,
            otherwise None.
        """
        # Hard cap on the number of tokens fed to the model.
        # NOTE(review): presumably the model's maximum positions; confirm.
        max_len = 2048
        if input_ids is None:
            if text is None:
                text = ""
            input_ids = torch.cuda.LongTensor([self.tokenizer(text)['input_ids']])
        if isinstance(input_ids, list):
            input_ids = torch.cuda.LongTensor(input_ids)
        if isinstance(labels, list):
            labels = torch.cuda.LongTensor(labels)
        lbls = labels if labels is not None else [None] * len(input_ids)
        # Built once rather than per row; padded label positions are ignored.
        loss_fct = CrossEntropyLoss(ignore_index=self.pad_token_id) if labels is not None else None

        res = []
        loss = None
        original_context_length = 0
        seq_len = self.seq_len
        for tokens, lbl in zip(input_ids, lbls):
            context_tokens = tokens.tolist()
            original_context_length = len(context_tokens)

            # Truncate real tokens BEFORE padding. Padding first and then
            # keeping the last max_len positions could keep pad tokens and
            # throw away real ones.
            if len(context_tokens) > max_len:
                context_tokens = context_tokens[-max_len:]
                if labels is not None:
                    lbl = torch.cuda.LongTensor(lbl.tolist()[-max_len:])
            context_length = len(context_tokens)

            # Grow the padded length in steps of 16 until the row fits. The
            # growth carries over to later rows. Padding never goes past max_len.
            while context_length > seq_len:
                seq_len += 16
            pad_count = min(seq_len, max_len) - context_length
            if pad_count > 0:
                context_tokens.extend([self.pad_token_id] * pad_count)
                if labels is not None:
                    lbl = torch.cuda.LongTensor(lbl.tolist() + [self.pad_token_id] * pad_count)

            context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
            context_length_tensor = torch.cuda.LongTensor([context_length])

            torch.distributed.broadcast(context_length_tensor, mpu.get_model_parallel_src_rank(),
                                        group=mpu.get_model_parallel_group())
            torch.distributed.broadcast(context_tokens_tensor, mpu.get_model_parallel_src_rank(),
                                        group=mpu.get_model_parallel_group())

            tokens = context_tokens_tensor.view(1, -1).contiguous()
            tokens = tokens.to(torch.cuda.current_device())
            attention_mask, _loss_mask, position_ids = get_masks_and_position_ids(
                tokens, self.pad_token_id, False, False)
            lm_logits = self.model(tokens, position_ids, attention_mask)
            loss = None
            if labels is not None:
                # Shift so that tokens < n predict n.
                shift_logits = lm_logits[..., :-1, :].contiguous()
                shift_labels = lbl[..., 1:].contiguous()
                loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
            res.append((lm_logits, loss))

        logits = torch.cat([x[0] for x in res], dim=0)[:, :original_context_length, :]
        if loss is not None:
            loss = [x[1] for x in res]
        return ModelOutput(logits, loss)