예제 #1
0
def create_input_data(config, seed=None):
    """Build random token ids and masked-LM positions plus the BigBird masks.

    Args:
        config: Mapping read for the keys ``num_layers``, ``seq_len``,
            ``nhead``, ``block_size``, ``window_size``,
            ``num_global_blocks``, ``num_rand_blocks``, ``seed``,
            ``vocab_size`` and ``batch_size``.
        seed: Optional seed for NumPy's global random generator. It is
            independent of ``config["seed"]``, which is only forwarded to
            ``create_bigbird_rand_mask_idx_list``.

    Returns:
        A tuple ``(rand_mask_idx_list, input_ids, masked_lm_positions)``:

        * ``rand_mask_idx_list`` -- the value returned by
          ``create_bigbird_rand_mask_idx_list``.
        * ``input_ids`` -- integer array of shape ``(batch_size, seq_len)``
          with values in ``[0, vocab_size)``.
        * ``masked_lm_positions`` -- 1-D ``int32`` array of length
          ``batch_size * int(seq_len * 0.15)`` holding indices into the
          flattened ``(batch_size * seq_len)`` token axis, ascending
          within each sequence.
    """
    if seed is not None:
        np.random.seed(seed)

    batch_size = config["batch_size"]
    seq_len = config["seq_len"]

    rand_mask_idx_list = create_bigbird_rand_mask_idx_list(
        config["num_layers"], seq_len, seq_len,
        config["nhead"], config["block_size"], config["window_size"],
        config["num_global_blocks"], config["num_rand_blocks"], config["seed"])

    input_ids = np.random.randint(
        low=0, high=config['vocab_size'], size=(batch_size, seq_len))

    # 15% of the tokens of every sequence are selected for prediction.
    num_to_predict = int(seq_len * 0.15)

    # Sample each row independently. Drawing the whole (batch, n) block in a
    # single ``replace=False`` call would demand uniqueness across the batch
    # and raise ValueError as soon as batch_size * num_to_predict > seq_len.
    positions = np.empty((batch_size, num_to_predict), dtype=np.int64)
    for row in range(batch_size):
        positions[row] = np.random.choice(
            seq_len, num_to_predict, replace=False)
    positions.sort(axis=-1)

    # Convert per-sequence positions to indices into the flattened batch so
    # they can be used with a single 1-D gather.
    row_offsets = np.arange(batch_size, dtype=np.int64)[:, None] * seq_len
    masked_lm_positions = (row_offsets + positions).reshape(-1).astype(
        np.int32)
    return rand_mask_idx_list, input_ids, masked_lm_positions
예제 #2
0
    def _collate_data(data, stack_fn=Stack()):
        """Collate pre-training samples into a flat list of batch arrays.

        Each sample is indexed as: 0 input_ids, 1 segment_ids,
        2 masked_lm_positions, 3 masked_lm_ids, 4 masked_lm_weights,
        5 next_sentence_labels.

        Fields 0, 1 and 5 are batched with ``stack_fn``. Fields 2-4 are
        concatenated over the batch; positions are rewritten as indices into
        the flattened ``(batch * seq_length)`` axis. The returned list then
        carries a one-element float32 array with the total number of masked
        tokens, followed by the entries of the BigBird random-mask list.
        """
        batch = [None] * len(data[0])
        for field in (0, 1, 5):
            batch[field] = stack_fn([sample[field] for sample in data])

        # The stacked input_ids must be 2-D; only the sequence length is used.
        _, seq_length = batch[0].shape
        total_masked = sum(len(sample[2]) for sample in data)

        flat_positions = np.full(total_masked, 0, dtype=np.int32)
        # Labels default to -1 and weights to 0 until filled in below.
        flat_labels = np.full([total_masked, 1], -1, dtype=np.int64)
        flat_weights = np.full([total_masked], 0, dtype="float32")

        # Organize as a 1D tensor for gather (or use gather_nd).
        cursor = 0
        for row, sample in enumerate(data):
            row_offset = row * seq_length
            for k, pos in enumerate(sample[2]):
                flat_positions[cursor] = row_offset + pos
                flat_labels[cursor] = sample[3][k]
                flat_weights[cursor] = sample[4][k]
                cursor += 1

        batch[2] = flat_positions
        batch[3] = flat_labels
        batch[4] = flat_weights
        batch.append(np.asarray([cursor], dtype=np.float32))

        seq_len = len(batch[0][0])
        mask_args = (
            config["num_layers"], seq_len, seq_len, config["nhead"],
            config["block_size"], config["window_size"],
            config["num_global_blocks"], config["num_rand_blocks"],
            config["seed"])
        batch += create_bigbird_rand_mask_idx_list(*mask_args)
        return batch
예제 #3
0
 def _collate_data(data, stack_fn=Stack()):
     """Collate ``{'text', 'label'}`` samples and append the BigBird masks.

     The returned list has one slot per field of the first sample; slot 0
     holds the stacked tokenized texts and slot 1 the stacked labels. The
     entries of the random-attention mask list are appended after those
     slots.
     """
     batch = [None] * len(data[0])

     token_rows = [_tokenize(sample['text']) for sample in data]
     batch[0] = stack_fn(token_rows)

     label_rows = [sample['label'] for sample in data]
     batch[1] = stack_fn(label_rows)

     # The random attention mask is sized from the first stacked sequence.
     seq_len = len(batch[0][0])
     mask_args = (
         config["num_layers"], seq_len, seq_len, config["nhead"],
         config["block_size"], config["window_size"],
         config["num_global_blocks"], config["num_rand_blocks"],
         config["seed"])
     batch += create_bigbird_rand_mask_idx_list(*mask_args)
     return batch
예제 #4
0
def collect_data(samples, dataset, config):
    """Collate samples into batch arrays and append the BigBird masks.

    Args:
        samples: Non-empty sequence of samples, each indexable as
            ``(input_ids, token_type_ids[, label])``.
        dataset: Object whose ``label_list`` attribute selects the label
            dtype: ``int64`` when it is truthy, ``float32`` otherwise.
        config: Mapping read for the keys ``num_layers``, ``nhead``,
            ``block_size``, ``window_size``, ``num_global_blocks``,
            ``num_rand_blocks`` and ``seed``.

    Returns:
        A list with one slot per sample field (stacked input_ids, stacked
        token_type_ids and, when present, stacked labels) followed by the
        entries returned by ``create_bigbird_rand_mask_idx_list``.
    """
    label_stack = Stack(dtype="int64" if dataset.label_list else "float32")
    feature_stack = Stack()

    num_fields = len(samples[0])
    out = [None] * num_fields
    out[0] = feature_stack([x[0] for x in samples])  # input_ids
    out[1] = feature_stack([x[1] for x in samples])  # token_type_ids
    # Labels exist only when a sample has a third field. The guard must be
    # strict: with exactly two fields ``out`` has no index 2.
    if num_fields > 2:
        # Pass a real list, like the calls above; stacking utilities built
        # on ``np.stack`` do not accept a generator.
        out[2] = label_stack([x[2] for x in samples])  # labels
    seq_len = len(out[0][0])
    # Construct the random attention mask for the random attention
    rand_mask_idx_list = create_bigbird_rand_mask_idx_list(
        config["num_layers"], seq_len, seq_len, config["nhead"],
        config["block_size"], config["window_size"],
        config["num_global_blocks"], config["num_rand_blocks"], config["seed"])
    out.extend(rand_mask_idx_list)
    return out