from copy import deepcopy
from typing import Dict, List, Optional

import numpy as np
from allennlp.data import Instance
from allennlp.data.tokenizers import Token
from allennlp.predictors import Predictor


def replace_one_token(
        predictor: Predictor,
        instances: List[Instance],
        reduction_field_name: str,
        gradient_field_name: str,
        n_beams: List[int],
        indices: List[List[int]],
        replaced_indices: List[List[int]],
        embedding_weight: np.ndarray,
        index_to_token: Dict[int, str],
        max_beam_size: int = 5,
        ignore_tokens: List[str] = ['@@NULL@@'],
):
    """
    replace one token in each example.
    each example branches out to at most max_beam_size new beams.
    we do not do beam verification here.

    batch structure:
    > example 0 beam 1
    > example 0 beam 2  # n_beams[0] = 2
    > example 1 beam 1  # n_beams[1] = 1
    > example 2 beam 1
    > example 2 beam 2  # n_beams[2] = 2
    >                   # n_beams[3] = 0
    """
    n_examples = len(n_beams)  # not batch size!

    # label the instances with the model's own predictions if necessary
    if 'label' not in instances[0].fields:
        outputs = predictor.predict_batch_instance(instances)
        instances = [
            predictor.predictions_to_labeled_instances(i, o)[0]
            for i, o in zip(instances, outputs)
        ]

    # one forward-backward pass to get the score of each token in the batch
    gradients, outputs = predictor.get_gradients(instances)
    grads = gradients[gradient_field_name]
    # first-order (hotflip-style) estimate of the objective change from
    # swapping in vocabulary entry k at position l of beam b
    hotflip_grad = np.einsum('bld,kd->blk', grads, embedding_weight)
    sign = -1  # sort scores in descending order

    # beams of example_idx: batch[start: start + n_beams[example_idx]]
    start = 0
    new_instances = []
    new_n_beams = [0 for _ in range(n_examples)]
    new_indices = []
    new_replaced_indices = []
    current_lengths = [
        real_sequence_length(x[reduction_field_name], ignore_tokens)
        for x in instances
    ]

    for example_idx in range(n_examples):
        """
        for each example_idx, current beams -> future beams
        1. find beam-level reduction candidates
        2. merge and sort them to get example-level reduction candidates
        """
        # skip if example_idx exited the search
        if n_beams[example_idx] == 0:
            continue

        # find beam-level candidates
        candidates = []  # (batch_index i, token j, replacement k)
        for i in range(start, start + n_beams[example_idx]):
            field = instances[i][reduction_field_name]
            # argsort the flattened scores
            indices_sorted = np.argsort(sign * hotflip_grad[i].ravel())
            # unravel into the original (position, replacement) shape
            indices_sorted = np.unravel_index(indices_sorted, hotflip_grad[i].shape)
            indices_sorted = np.stack(indices_sorted, 1)
            beam_candidates = [
                (i, j, k) for j, k in indices_sorted
                if (
                    j < field.sequence_length()
                    and field.tokens[j].text not in ignore_tokens
                )
            ]
            candidates += beam_candidates[:max_beam_size]

        # no beam-level candidate found, skip
        if len(candidates) == 0:
            start += n_beams[example_idx]
            continue

        # gather scores of all example-level candidates
        # and sort them to get example-level candidates
        candidates = np.asarray(candidates)
        scores = hotflip_grad[candidates[:, 0], candidates[:, 1], candidates[:, 2]]
        candidate_scores = sorted(zip(candidates, scores), key=lambda x: sign * x[1])
        candidates = [c for c, s in candidate_scores[:max_beam_size]]

        # each candidate should be a valid token in the beam it belongs to
        assert all(j < current_lengths[i] for i, j, k in candidates)

        for i, j, k in candidates:
            new_instance = deepcopy(instances[i])
            new_instance[reduction_field_name].tokens = (
                new_instance[reduction_field_name].tokens[0:j]
                + [Token(index_to_token[k])]
                + new_instance[reduction_field_name].tokens[j + 1:]
            )
            new_instance.indexed = False
            new_n_beams[example_idx] += 1
            new_instances.append(new_instance)
            # a replacement keeps the sequence length, so indices are unchanged
            new_replaced_indices.append(replaced_indices[i] + [indices[i][j]])
            new_indices.append(indices[i])

        # move starting position to next example
        start += n_beams[example_idx]

    return new_instances, new_n_beams, new_indices, new_replaced_indices
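# Illustration (not part of the original code): a toy walk-through of the
# first-order hotflip ranking used in replace_one_token. The einsum scores
# every (position j, replacement k) pair of a beam, and flatten-argsort-
# unravel turns those scores into a ranked list of candidate swaps. All
# shapes and values below are invented for the demo.
def _demo_hotflip_ranking():
    rng = np.random.default_rng(0)
    toy_grads = rng.normal(size=(1, 4, 8))      # (batch, length, embed_dim)
    toy_embedding = rng.normal(size=(10, 8))    # (vocab_size, embed_dim)
    toy_scores = np.einsum('bld,kd->blk', toy_grads, toy_embedding)
    order = np.argsort(-toy_scores[0].ravel())  # descending, as with sign = -1
    pairs = np.stack(np.unravel_index(order, toy_scores[0].shape), 1)
    return pairs[:5]  # the five best (token position j, replacement id k) pairs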
def remove_one_token(
        predictor: Predictor,
        instances: List[Instance],
        reduction_field_name: str,
        gradient_field_name: str,
        n_beams: List[int],
        indices: List[List[int]],
        removed_indices: List[List[int]],
        token_id_field_name: Optional[str] = None,
        embedding_weight: Optional[np.ndarray] = None,
        max_beam_size: int = 5,
        min_sequence_length: int = 1,
        ignore_tokens: List[str] = ['@@NULL@@'],
):
    """
    remove one token from each example.
    each example branches out to at most max_beam_size new beams.
    we do not do beam verification here.

    batch structure:
    > example 0 beam 1
    > example 0 beam 2  # n_beams[0] = 2
    > example 1 beam 1  # n_beams[1] = 1
    > example 2 beam 1
    > example 2 beam 2  # n_beams[2] = 2
    >                   # n_beams[3] = 0
    """
    n_examples = len(n_beams)  # not batch size!

    # label the instances with the model's own predictions if necessary
    if 'label' not in instances[0].fields:
        outputs = predictor.predict_batch_instance(instances)
        instances = [
            predictor.predictions_to_labeled_instances(i, o)[0]
            for i, o in zip(instances, outputs)
        ]

    # one forward-backward pass to get the score of each token in the batch
    gradients, outputs = predictor.get_gradients(instances)
    grads = gradients[gradient_field_name]
    if embedding_weight is not None:  # truth-testing an array raises ValueError
        token_ids = outputs[token_id_field_name].cpu().numpy()
        hotflip_grad = np.einsum('bld,kd->blk', grads, embedding_weight)
        # gather, for each position, the gradient component of its actual token
        onehot_grad = np.take_along_axis(
            hotflip_grad, token_ids[..., np.newaxis], axis=2).squeeze(2)
    else:
        # fall back to the squared L2 norm of the embedding gradient
        onehot_grad = np.einsum('bld,bld->bl', grads, grads)

    # beams of example_idx: batch[start: start + n_beams[example_idx]]
    start = 0
    new_instances = []
    new_n_beams = [0 for _ in range(n_examples)]
    new_indices = []
    new_removed_indices = []
    current_lengths = [
        real_sequence_length(x[reduction_field_name], ignore_tokens)
        for x in instances
    ]

    for example_idx in range(n_examples):
        """
        for each example_idx, current beams -> future beams
        1. find beam-level reduction candidates
        2. merge and sort them to get example-level reduction candidates
        """
        # skip if example_idx exited the search
        if n_beams[example_idx] == 0:
            continue

        # find beam-level candidates
        candidates = []  # (batch_index i, token j)
        for i in range(start, start + n_beams[example_idx]):
            if current_lengths[i] <= min_sequence_length:
                # nothing to reduce
                continue
            field = instances[i][reduction_field_name]
            beam_candidates = [
                (i, j) for j in np.argsort(-onehot_grad[i])
                if (
                    j < field.sequence_length()
                    and field.tokens[j].text not in ignore_tokens
                )
            ]
            candidates += beam_candidates[:max_beam_size]

        # no beam-level candidate found, skip
        if len(candidates) == 0:
            start += n_beams[example_idx]
            continue

        # gather scores of all example-level candidates
        # and sort them to get example-level candidates
        candidates = np.asarray(candidates)
        scores = onehot_grad[candidates[:, 0], candidates[:, 1]]
        candidate_scores = sorted(zip(candidates, scores), key=lambda x: -x[1])
        candidates = [c for c, s in candidate_scores[:max_beam_size]]

        # each candidate should be a valid token in the beam it belongs to
        assert all(j < current_lengths[i] for i, j in candidates)

        for i, j in candidates:
            new_instance = deepcopy(instances[i])
            new_instance[reduction_field_name].tokens = (
                new_instance[reduction_field_name].tokens[0:j]
                + new_instance[reduction_field_name].tokens[j + 1:]
            )
            new_instance.indexed = False
            new_n_beams[example_idx] += 1
            new_instances.append(new_instance)
            # record which original token was removed, then drop its index
            new_removed_indices.append(removed_indices[i] + [indices[i][j]])
            new_indices.append(indices[i][:j] + indices[i][j + 1:])

        # move starting position to next example
        start += n_beams[example_idx]

    return new_instances, new_n_beams, new_indices, new_removed_indices
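# Illustration (not part of the original code): a hedged sketch of how
# remove_one_token could be driven as a beam search over a batch of examples.
# `my_predictor`, `my_instances`, the 'tokens' field, and the 'grad_input_1'
# gradient key are placeholder assumptions; a real caller would also verify
# beams between steps (e.g. keep only beams whose prediction is unchanged)
# and retire an example by setting its n_beams entry to 0.
def _demo_reduction_loop(my_predictor: Predictor, my_instances: List[Instance]):
    n_beams = [1 for _ in my_instances]  # start with one beam per example
    indices = [
        list(range(real_sequence_length(x['tokens'], ['@@NULL@@'])))
        for x in my_instances
    ]
    removed = [[] for _ in my_instances]
    instances = my_instances
    for _ in range(3):  # a fixed number of reduction steps, for the demo
        instances, n_beams, indices, removed = remove_one_token(
            my_predictor,
            instances,
            reduction_field_name='tokens',
            gradient_field_name='grad_input_1',
            n_beams=n_beams,
            indices=indices,
            removed_indices=removed,
        )
        if len(instances) == 0:  # every example exited the search
            break
    return instances, removed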