Example 1
from copy import deepcopy
from typing import Dict, List

import numpy as np

from allennlp.data import Instance
from allennlp.data.tokenizers import Token
from allennlp.predictors import Predictor

# real_sequence_length, used below, is a project helper; a sketch follows this example.
def replace_one_token(
    predictor: Predictor,
    instances: List[Instance],
    reduction_field_name: str,
    gradient_field_name: str,
    n_beams: List[int],
    indices: List[List[int]],
    replaced_indices: List[List[int]],
    embedding_weight: np.ndarray,
    index_to_token: Dict[int, str],
    max_beam_size: int = 5,
    ignore_tokens: List[str] = ['@@NULL@@'],
):
    """
    remove one token from each example.
    each example branches out to at most max_beam_size new beams.
    we do not do beam verification here.

    batch structure:
    > example 0 beam 1
    > example 0 beam 2  # n_beams[0] = 2
    > example 1 beam 1  # n_beams[1] = 1
    > example 2 beam 1
    > example 2 beam 2  # n_beams[2] = 2
    >                   # n_beams[3] = 0

    """
    n_examples = len(n_beams)  # not batch size!

    if 'label' not in instances[0].fields:
        outputs = predictor.predict_batch_instance(instances)
        instances = [
            predictor.predictions_to_labeled_instances(i, o)[0]
            for i, o in zip(instances, outputs)
        ]

    # one forward-backward pass to get the score of each token in the batch
    gradients, outputs = predictor.get_gradients(instances)
    grads = gradients[gradient_field_name]
    # first-order (HotFlip-style) score of swapping each position to every vocabulary token
    hotflip_grad = np.einsum('bld,kd->blk', grads, embedding_weight)
    # sign = -1 makes the argsort/sorted calls below rank candidates from highest score to lowest
    sign = -1

    # beams of example_idx: batch[start: start + n_beams[example_idx]]
    start = 0
    new_instances = []
    new_n_beams = [0 for _ in range(n_examples)]
    new_indices = []
    new_replaced_indices = []
    current_lengths = [
        real_sequence_length(x[reduction_field_name], ignore_tokens)
        for x in instances
    ]

    for example_idx in range(n_examples):
        """
        for each example_idx, current beams -> future beams
        1. find beam-level reduction candidates
        2. merge and sort them to get example-level reduction candidates
        """
        # skip if example_idx exited the search
        if n_beams[example_idx] == 0:
            continue

        # find beam-level candidates
        candidates = []  # (batch_index i, token j, replacement k)
        for i in range(start, start + n_beams[example_idx]):
            field = instances[i][reduction_field_name]
            # argsort the flattened scores
            indices_sorted = np.argsort(sign * hotflip_grad[i].ravel())
            # unravel into original shape
            indices_sorted = np.unravel_index(indices_sorted,
                                              hotflip_grad[i].shape)
            indices_sorted = np.stack(indices_sorted, 1)
            beam_candidates = [
                (i, j, k) for j, k in indices_sorted
                if (j < field.sequence_length()
                    and field.tokens[j].text not in ignore_tokens)
            ]
            candidates += beam_candidates[:max_beam_size]

        # no beam-level candidate found, skip
        if len(candidates) == 0:
            start += n_beams[example_idx]
            continue

        # gather scores of all example-level candidates
        # sort them to get example-level candidates
        candidates = np.asarray(candidates)
        scores = hotflip_grad[candidates[:, 0], candidates[:, 1],
                              candidates[:, 2]]
        candidate_scores = sorted(zip(candidates, scores),
                                  key=lambda x: sign * x[1])
        candidates = [c for c, s in candidate_scores[:max_beam_size]]

        # each candidate must point at a valid token in the beam it belongs to
        assert all(j < current_lengths[i] for i, j, k in candidates)

        for i, j, k in candidates:
            new_instance = deepcopy(instances[i])
            new_instance[reduction_field_name].tokens = (
                new_instance[reduction_field_name].tokens[0:j] +
                [Token(index_to_token[k])] +
                new_instance[reduction_field_name].tokens[j + 1:])
            new_instance.indexed = False

            new_n_beams[example_idx] += 1
            new_instances.append(new_instance)
            new_replaced_indices.append(replaced_indices[i] + [indices[i][j]])
            new_indices.append(indices[i])

        # move starting position to next example
        start += n_beams[example_idx]

    return new_instances, new_n_beams, new_indices, new_replaced_indices
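Both examples call a real_sequence_length helper that is not defined on this page. A minimal sketch of what it presumably does, assuming it simply counts the tokens of the field that are not in ignore_tokens (the name and call signature come from the code above; the body is an assumption):

def real_sequence_length(field, ignore_tokens=('@@NULL@@',)):
    # sketch only: count the tokens that are not special/ignored tokens
    return sum(1 for token in field.tokens if token.text not in ignore_tokens)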
Example 2
def remove_one_token(
        predictor: Predictor,
        instances: List[Instance],
        reduction_field_name: str,
        gradient_field_name: str,
        n_beams: List[int],
        indices: List[List[int]],
        removed_indices: List[List[int]],
        token_id_field_name: str = None,
        embedding_weight: np.ndarray = None,
        max_beam_size: int = 5,
        min_sequence_length: int = 1,
        ignore_tokens: List[str] = ['@@NULL@@'],
):
    """
    remove one token from each example.
    each example branches out to at most max_beam_size new beams.
    we do not do beam verification here.

    batch structure:
    > example 0 beam 1
    > example 0 beam 2  # n_beams[0] = 2
    > example 1 beam 1  # n_beams[1] = 1
    > example 2 beam 1
    > example 2 beam 2  # n_beams[2] = 2
    >                   # n_beams[3] = 0

    """
    n_examples = len(n_beams)  # not batch size!

    if 'label' not in instances[0].fields:
        outputs = predictor.predict_batch_instance(instances)
        instances = [predictor.predictions_to_labeled_instances(i, o)[0]
                     for i, o in zip(instances, outputs)]

    # one forward-backward pass to get the score of each token in the batch
    gradients, outputs = predictor.get_gradients(instances)
    grads = gradients[gradient_field_name]

    if embedding_weight is not None:
        # score each position by the gradient dotted with the embedding of the
        # token actually at that position (first-order "one-hot" approximation)
        token_ids = outputs[token_id_field_name].cpu().numpy()
        hotflip_grad = np.einsum('bld,kd->blk', grads, embedding_weight)
        onehot_grad = np.take_along_axis(
            hotflip_grad, token_ids[..., None], axis=2).squeeze(2)
    else:
        # fall back to the squared gradient norm at each position
        onehot_grad = np.einsum('bld,bld->bl', grads, grads)

    # beams of example_idx: batch[start: start + n_beams[example_idx]]
    start = 0
    new_instances = []
    new_n_beams = [0 for _ in range(n_examples)]
    new_indices = []
    new_removed_indices = []
    current_lengths = [real_sequence_length(x[reduction_field_name], ignore_tokens)
                       for x in instances]

    for example_idx in range(n_examples):
        """
        for each example_idx, current beams -> future beams
        1. find beam-level reduction candidates
        2. merge and sort them to get example-level reduction candidates
        """
        # skip if example_idx exited the search
        if n_beams[example_idx] == 0:
            continue

        # find beam-level candidates
        candidates = []  # (batch_index i, token j)
        for i in range(start, start + n_beams[example_idx]):
            if current_lengths[i] <= min_sequence_length:
                # nothing to reduce
                continue
            field = instances[i][reduction_field_name]
            beam_candidates = [
                (i, j) for j in np.argsort(- onehot_grad[i])
                if (
                    j < field.sequence_length()
                    and field.tokens[j].text not in ignore_tokens
                )
            ]
            candidates += beam_candidates[:max_beam_size]

        # no beam-level candidate found, skip
        if len(candidates) == 0:
            start += n_beams[example_idx]
            continue

        # gather scores of all example-level candidates
        # sort them to get example-level candidates
        candidates = np.asarray(candidates)
        scores = onehot_grad[candidates[:, 0], candidates[:, 1]]
        candidate_scores = sorted(zip(candidates, scores), key=lambda x: -x[1])
        candidates = [c for c, s in candidate_scores[:max_beam_size]]

        # each candidate must point at a valid token in the beam it belongs to
        assert all(j < current_lengths[i] for i, j in candidates)

        for i, j in candidates:
            new_instance = deepcopy(instances[i])
            new_instance[reduction_field_name].tokens = (
                new_instance[reduction_field_name].tokens[0: j]
                + new_instance[reduction_field_name].tokens[j + 1:]
            )
            new_instance.indexed = False

            new_n_beams[example_idx] += 1
            new_instances.append(new_instance)
            new_removed_indices.append(removed_indices[i] + [indices[i][j]])
            new_indices.append(indices[i][:j] + indices[i][j + 1:])

        # move starting position to next example
        start += n_beams[example_idx]

    return new_instances, new_n_beams, new_indices, new_removed_indices
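A hedged sketch of how a caller might drive remove_one_token in an outer reduction loop. The initialization, stopping condition, and the verification step that the docstring says is not done inside the function are illustrative assumptions, not part of the example above:

def reduce_all(predictor, instances, field_name, grad_name, max_steps=20):
    # start with one beam per example; indices track each token's original position
    n_beams = [1] * len(instances)
    indices = [list(range(len(inst[field_name].tokens))) for inst in instances]
    removed = [[] for _ in instances]
    for _ in range(max_steps):
        instances, n_beams, indices, removed = remove_one_token(
            predictor, instances, field_name, grad_name,
            n_beams, indices, removed)
        if not instances:
            break
        # a real driver would verify beams here (e.g. drop beams whose prediction
        # changed) and set n_beams to 0 for examples that have finished reducing
    return instances, n_beams, indices, removed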