Example #1
    def tensorize(self, datapoint: str, return_str_rep: bool = False):
        if self.splitting_kind == "token":
            token_idxs = self.vocabulary.get_id_or_unk(datapoint)
            str_repr = datapoint
        elif self.splitting_kind == "subtoken":
            subtoks = split_identifier_into_parts(datapoint)
            if len(subtoks) == 0:
                subtoks = [Vocabulary.get_unk()]
            token_idxs = self.vocabulary.get_id_or_unk_multiple(subtoks)
        elif self.splitting_kind == "bpe":
            if len(datapoint) == 0:
                datapoint = "<empty>"
            token_idxs = self.vocabulary.get_id_or_unk_for_text(datapoint)
            if return_str_rep:  # Only computed on demand, for efficiency.
                str_repr = self.vocabulary.tokenize(datapoint)
        elif self.splitting_kind == "char":
            token_idxs = self.vocabulary.tensorize_str(datapoint)
            if return_str_rep:
                str_repr = datapoint[:self.vocabulary.max_char_length]
        else:
            raise ValueError(
                f'Unrecognized token splitting method "{self.splitting_kind}".'
            )

        if return_str_rep:
            return token_idxs, str_repr
        return token_idxs
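As a usage note, all four branches fall back to the UNK id for out-of-vocabulary input. A minimal, self-contained sketch of that fallback, using a plain dict as a hypothetical stand-in for the Vocabulary API (the token ids are made up):

UNK = "%UNK%"
vocab = {UNK: 0, "return": 1, "self": 2}

def get_id_or_unk(token: str) -> int:
    # Hypothetical stand-in for Vocabulary.get_id_or_unk.
    return vocab.get(token, vocab[UNK])

print(get_id_or_unk("self"))    # 2
print(get_id_or_unk("tensor"))  # 0, i.e. the UNK id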
Example #2
def _evaluate_f1(best_predictions: List[List[np.ndarray]],
                 best_predictions_probs: List[np.ndarray], vocab: Vocabulary,
                 true_labels: np.ndarray):
    """Accumulates exact-match, precision/recall, and UNK statistics per suggestion."""
    true_labels = clean_target_from_padding(true_labels)
    result_accumulator = PointSuggestionEvaluator()
    unk_id = vocab.get_id_or_unk(vocab.get_unk())

    for x_pred, x_prob, y_target in zip(best_predictions,
                                        best_predictions_probs, true_labels):
        confidences = x_prob.tolist()
        is_exact_prediction = [np.all(pred == y_target) for pred in x_pred]
        precision_recall = [
            token_precision_recall(pred.T, y_target) for pred in x_pred
        ]
        is_unknown_word_predicted = [
            np.all(suggestion == unk_id) for suggestion in x_pred
        ]
        unk_word_accuracy = [
            unk_acc(suggestion.T, y_target, unk_id) for suggestion in x_pred
        ]
        result_accumulator.add_result(confidences, is_exact_prediction,
                                      is_unknown_word_predicted,
                                      precision_recall, unk_word_accuracy)

    return result_accumulator
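To make the per-suggestion checks inside the loop concrete, here is a small self-contained numpy sketch of the exact-match and all-UNK tests; the arrays and the unk id are fabricated for illustration:

import numpy as np

unk_id = 3
y_target = np.array([5, 7, 2])
x_pred = [np.array([5, 7, 2]), np.array([3, 3, 3])]  # two candidate suggestions

is_exact_prediction = [np.all(pred == y_target) for pred in x_pred]
is_unknown_word_predicted = [np.all(pred == unk_id) for pred in x_pred]
print(is_exact_prediction)        # [True, False]
print(is_unknown_word_predicted)  # [False, True]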
Example #3
    def load_data_from_sample(cls,
                              encoder_label: str,
                              hyperparameters: Dict[str, Any],
                              metadata: Dict[str, Any],
                              data_to_load: Any,
                              function_name: Optional[str],
                              result_holder: Dict[str, Any],
                              is_test: bool = True) -> bool:
        """
        Saves two versions of both the code and the query: one using the docstring as the query and the other using the
        function-name as the query, and replacing the function name in the code with an out-of-vocab token.
        Sub-tokenizes, converts, and pads both versions, and rejects empty samples.
        """
        # Save the two versions of the code and query:
        data_holder = {
            QueryType.DOCSTRING.value: data_to_load,
            QueryType.FUNCTION_NAME.value: None
        }
        # Skip samples where the function name is very short, because it probably has too little information
        # to be a good search query.
        if not is_test and hyperparameters['fraction_using_func_name'] > 0. and function_name and \
                len(function_name) >= hyperparameters['min_len_func_name_for_query']:
            if encoder_label == 'query':
                # Set the query tokens to the function name, broken up into its sub-tokens:
                data_holder[QueryType.FUNCTION_NAME.value] = \
                    split_identifier_into_parts(function_name)
            elif encoder_label == 'code':
                # In the code, replace the function name with the out-of-vocab token everywhere it appears:
                data_holder[QueryType.FUNCTION_NAME.value] = [
                    Vocabulary.get_unk() if token == function_name else token
                    for token in data_to_load
                ]

        # Sub-tokenize, convert, and pad both versions:
        for key, data in data_holder.items():
            if not data:
                result_holder[f'{encoder_label}_tokens_{key}'] = None
                result_holder[f'{encoder_label}_tokens_mask_{key}'] = None
                result_holder[f'{encoder_label}_tokens_length_{key}'] = None
                continue
            if hyperparameters[f'{encoder_label}_use_subtokens']:
                data = cls._to_subtoken_stream(
                    data,
                    mark_subtoken_end=hyperparameters[
                        f'{encoder_label}_mark_subtoken_end'])
            tokens, tokens_mask = \
                convert_and_pad_token_sequence(metadata['token_vocab'], list(data),
                                               hyperparameters[f'{encoder_label}_max_num_tokens'])
            # Note that we share the result_holder with different encoders, and so we need to make our identifiers
            # unique-ish
            result_holder[f'{encoder_label}_tokens_{key}'] = tokens
            result_holder[f'{encoder_label}_tokens_mask_{key}'] = tokens_mask
            result_holder[f'{encoder_label}_tokens_length_{key}'] = int(
                np.sum(tokens_mask))

        if result_holder[f'{encoder_label}_tokens_mask_{QueryType.DOCSTRING.value}'] is None or \
                int(np.sum(result_holder[f'{encoder_label}_tokens_mask_{QueryType.DOCSTRING.value}'])) == 0:
            return False

        return True
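convert_and_pad_token_sequence is not shown in this example. A plausible minimal version, assuming it maps tokens to ids, truncates or pads to a fixed length, and returns a mask of real positions (the actual implementation may differ), would be:

import numpy as np

def convert_and_pad_token_sequence(vocab, tokens, max_num_tokens):
    # Hypothetical sketch: ids for the first max_num_tokens tokens
    # (0 = UNK here), zero-padded, plus a mask marking real positions.
    token_ids = np.zeros(max_num_tokens, dtype=np.int32)
    mask = np.zeros(max_num_tokens, dtype=np.float32)
    for i, token in enumerate(tokens[:max_num_tokens]):
        token_ids[i] = vocab.get(token, 0)
        mask[i] = 1.0
    return token_ids, mask

toy_vocab = {"def": 1, "return": 2}
ids, mask = convert_and_pad_token_sequence(toy_vocab, ["def", "foo", "return"], 5)
print(ids)                # [1 0 2 0 0]
print(mask)               # [1. 1. 1. 0. 0.]
print(int(np.sum(mask)))  # 3, the length stored above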
Example #4
def get_dataset_from(
        data_dirs: List[RichPath],
        use_func_names: bool = False,
        max_files_per_dir: Optional[int] = None) -> List[Dict[str, Any]]:
    data_files = sorted(
        get_data_files_from_directory(data_dirs, max_files_per_dir))
    data = list(
        chain.from_iterable(
            data_pipeline.combined_samples_generator(
                {data_pipeline.CODE_TOKENS_LABEL: f})
            for f in data_files))

    if use_func_names:
        # This task tries to match the function name to the code, by setting the function name as the query
        for sample in data:
            # Replace the query tokens with the function name, broken up into its sub-tokens:
            sample['docstring_tokens'] = split_identifier_into_parts(
                sample['func_name'])

            # In the code, replace the function name with the out-of-vocab token everywhere it appears:
            sample['code_tokens'] = [
                Vocabulary.get_unk() if token == sample['func_name'] else token
                for token in sample['code_tokens']
            ]
    return data
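Both this example and Example #3 lean on split_identifier_into_parts. A rough stand-in that handles snake_case and simple camelCase (the real helper covers more cases, e.g. digits and consecutive capitals) could look like:

import re

def split_identifier_into_parts(identifier: str):
    # Simplified stand-in: split on underscores and lower-to-upper
    # camelCase boundaries, then lowercase each part.
    parts = re.split(r"_|(?<=[a-z0-9])(?=[A-Z])", identifier)
    return [p.lower() for p in parts if p]

print(split_identifier_into_parts("getVocabExtendedNlToken"))
# ['get', 'vocab', 'extended', 'nl', 'token']
print(split_identifier_into_parts("beam_decode"))
# ['beam', 'decode']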
Example #5
    def get_vocab_extended_nl_token(self, token_id, inp_ids, inp_tokens):
        """Resolves a token id, mapping extended (copy) ids back to input tokens."""
        if token_id < len(self.__nl_vocabulary):
            return self.get_nl_token(token_id)
        elif token_id in inp_ids:
            copy_idx = inp_ids.index(token_id)
            return inp_tokens[copy_idx]
        else:
            return Vocabulary.get_unk()
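This method is the read-out side of a copy mechanism: ids at or beyond the fixed vocabulary size refer to tokens copied from the input. The same lookup on toy data (all ids and tokens below are invented):

UNK = "%UNK%"
nl_vocab = ["<s>", "</s>", UNK, "the", "value"]  # fixed vocabulary, ids 0..4
inp_tokens = ["fooBar", "baz"]                   # per-example copy candidates
inp_ids = [5, 6]                                 # their extended ids

def vocab_extended_token(token_id):
    if token_id < len(nl_vocab):
        return nl_vocab[token_id]
    if token_id in inp_ids:
        return inp_tokens[inp_ids.index(token_id)]
    return UNK

print(vocab_extended_token(3))  # 'the'
print(vocab_extended_token(6))  # 'baz', copied from the input
print(vocab_extended_token(9))  # '%UNK%'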
Example #6
    def greedy_decode(self, initial_state, encoder_hidden_states, masks,
                      max_out_len, batch_data, device):
        """Greedily generates the output sequence."""
        # Derived from https://github.com/budzianowski/PyTorch-Beam-Search-Decoding/blob/9f6b66f43d2e05175dabcc024f79e1d37a667070/decode_beam.py#L163
        batch_size = initial_state.shape[0]
        decoder_state = initial_state
        decoder_input = torch.tensor(
            [[self.embedding_store.get_nl_id(START)]] * batch_size,
            device=device)

        decoded_batch = np.zeros([batch_size, max_out_len], dtype=np.int64)
        decoded_batch_scores = np.zeros([batch_size, max_out_len])

        for i in range(max_out_len):
            decoder_input_embeddings = self.embedding_store.get_nl_embeddings(
                decoder_input)
            decoder_attention_states, decoder_state, generation_logprobs, copy_logprobs = self.decode(
                decoder_state, decoder_input_embeddings, encoder_hidden_states,
                masks)

            generation_logprobs = generation_logprobs.squeeze(1)
            copy_logprobs = copy_logprobs.squeeze(1)

            prob_scores = torch.zeros([
                generation_logprobs.shape[0],
                generation_logprobs.shape[-1] + copy_logprobs.shape[-1]
            ],
                                      dtype=torch.float32,
                                      device=device)
            prob_scores[:, :generation_logprobs.shape[-1]] = torch.exp(
                generation_logprobs)
            for b in range(generation_logprobs.shape[0]):
                for c, inp_id in enumerate(batch_data.input_ids[b]):
                    prob_scores[b,
                                inp_id] = prob_scores[b, inp_id] + torch.exp(
                                    copy_logprobs[b, c])

            predicted_ids = torch.argmax(prob_scores, dim=-1)
            # Move to CPU before writing into the numpy buffers.
            decoded_batch_scores[:, i] = prob_scores[
                torch.arange(prob_scores.shape[0]),
                predicted_ids].detach().cpu().numpy()
            decoded_batch[:, i] = predicted_ids.cpu().numpy()

            unks = torch.ones(predicted_ids.shape[0],
                              dtype=torch.int64,
                              device=device) * self.embedding_store.get_nl_id(
                                  Vocabulary.get_unk())
            decoder_input = torch.where(
                predicted_ids < len(self.embedding_store.nl_vocabulary),
                predicted_ids, unks).unsqueeze(1)
            decoder_state = decoder_state.squeeze(0)

        return decoded_batch, decoded_batch_scores
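The nested loop that folds copy probabilities into the distribution is O(batch x source length) in Python; Example #8 below vectorizes the same accumulation with a scatter-add. A self-contained sketch of that operation on toy tensors (all values fabricated):

import torch

vocab_size, src_len = 4, 3
gen_probs = torch.tensor([[0.4, 0.3, 0.2, 0.1]])  # generation distribution
copy_probs = torch.tensor([[0.5, 0.3, 0.2]])      # copy prob per source position
input_ids = torch.tensor([[2, 5, 2]])             # 5 is an extended (copy-only) id

# Extended distribution: vocabulary entries followed by copy slots.
scores = torch.zeros(1, vocab_size + src_len)
scores[:, :vocab_size] = gen_probs
scores.scatter_add_(1, input_ids, copy_probs)     # duplicate indices accumulate
print(scores)  # id 2 collects 0.2 + 0.5 + 0.2; id 5 collects 0.3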
Example #7
    def unk_token(self) -> str:
        return Vocabulary.get_unk()
Example #8
    def beam_decode(self, initial_state, encoder_hidden_states,
                    code_hidden_states, old_nl_hidden_states, masks,
                    max_out_len, batch_data, code_masks, old_nl_masks, device):
        """Beam search. Generates the top K candidate predictions."""
        batch_size = initial_state.shape[0]
        decoded_batch = [list() for _ in range(batch_size)]
        decoded_batch_scores = np.zeros([batch_size, BEAM_SIZE])

        decoder_input = torch.tensor(
            [[self.embedding_store.get_nl_id(START)]] * batch_size,
            device=device)
        decoder_input = decoder_input.unsqueeze(1)
        decoder_state = initial_state.unsqueeze(1).expand(
            -1, decoder_input.shape[1], -1).reshape(-1,
                                                    initial_state.shape[-1])

        beam_scores = torch.ones([batch_size, 1],
                                 dtype=torch.float32,
                                 device=device)
        beam_status = torch.zeros([batch_size, 1],
                                  dtype=torch.uint8,
                                  device=device)
        beam_predicted_ids = torch.full([batch_size, 1, max_out_len],
                                        self.embedding_store.get_end_id(),
                                        dtype=torch.int64,
                                        device=device)

        for i in range(max_out_len):
            beam_size = decoder_input.shape[1]
            if beam_status[:, 0].sum() == batch_size:
                break

            tiled_encoder_states = encoder_hidden_states.unsqueeze(1).expand(
                -1, beam_size, -1, -1)
            tiled_masks = masks.unsqueeze(1).expand(-1, beam_size, -1, -1)
            tiled_code_hidden_states = code_hidden_states.unsqueeze(1).expand(
                -1, beam_size, -1, -1)
            tiled_code_masks = code_masks.unsqueeze(1).expand(
                -1, beam_size, -1, -1)
            tiled_old_nl_hidden_states = old_nl_hidden_states.unsqueeze(
                1).expand(-1, beam_size, -1, -1)
            tiled_old_nl_masks = old_nl_masks.unsqueeze(1).expand(
                -1, beam_size, -1, -1)

            flat_decoder_input = decoder_input.reshape(-1,
                                                       decoder_input.shape[-1])
            flat_encoder_states = tiled_encoder_states.reshape(
                -1, tiled_encoder_states.shape[-2],
                tiled_encoder_states.shape[-1])
            flat_masks = tiled_masks.reshape(-1, tiled_masks.shape[-2],
                                             tiled_masks.shape[-1])
            flat_code_hidden_states = tiled_code_hidden_states.reshape(
                -1, tiled_code_hidden_states.shape[-2],
                tiled_code_hidden_states.shape[-1])
            flat_code_masks = tiled_code_masks.reshape(
                -1, tiled_code_masks.shape[-2], tiled_code_masks.shape[-1])
            flat_old_nl_hidden_states = tiled_old_nl_hidden_states.reshape(
                -1, tiled_old_nl_hidden_states.shape[-2],
                tiled_old_nl_hidden_states.shape[-1])
            flat_old_nl_masks = tiled_old_nl_masks.reshape(
                -1, tiled_old_nl_masks.shape[-2], tiled_old_nl_masks.shape[-1])

            decoder_input_embeddings = self.embedding_store.get_nl_embeddings(
                flat_decoder_input)
            decoder_attention_states, flat_decoder_state, generation_logprobs, copy_logprobs = self.decode(
                decoder_state, decoder_input_embeddings, flat_encoder_states,
                flat_code_hidden_states, flat_old_nl_hidden_states, flat_masks,
                flat_code_masks, flat_old_nl_masks)

            generation_logprobs = generation_logprobs.squeeze(1)
            copy_logprobs = copy_logprobs.squeeze(1)

            generation_logprobs = generation_logprobs.reshape(
                batch_size, beam_size, generation_logprobs.shape[-1])
            copy_logprobs = copy_logprobs.reshape(batch_size, beam_size,
                                                  copy_logprobs.shape[-1])

            prob_scores = torch.zeros([
                batch_size, beam_size,
                generation_logprobs.shape[-1] + copy_logprobs.shape[-1]
            ],
                                      dtype=torch.float32,
                                      device=device)
            prob_scores[:, :, :generation_logprobs.shape[-1]] = torch.exp(
                generation_logprobs)

            # Factoring in the copy scores
            expanded_token_ids = batch_data.input_ids.unsqueeze(1).expand(
                -1, beam_size, -1)
            prob_scores += scatter_add(src=torch.exp(copy_logprobs),
                                       index=expanded_token_ids,
                                       out=torch.zeros_like(prob_scores))

            top_scores_per_beam, top_indices_per_beam = torch.topk(prob_scores,
                                                                   k=BEAM_SIZE,
                                                                   dim=-1)

            updated_scores = torch.einsum('eb,ebm->ebm', beam_scores,
                                          top_scores_per_beam)
            retained_scores = beam_scores.unsqueeze(-1).expand(
                -1, -1, top_scores_per_beam.shape[-1])

            # Trying to keep at most one ray corresponding to completed beams
            end_mask = (torch.arange(beam_size) == 0).type(
                torch.float32).to(device)
            end_scores = torch.einsum('b,ebm->ebm', end_mask, retained_scores)

            possible_next_scores = torch.where(
                beam_status.unsqueeze(-1) == 1, end_scores, updated_scores)
            possible_next_status = torch.where(
                top_indices_per_beam == self.embedding_store.get_end_id(),
                torch.ones(
                    [batch_size, beam_size, top_scores_per_beam.shape[-1]],
                    dtype=torch.uint8,
                    device=device),
                beam_status.unsqueeze(-1).expand(
                    -1, -1, top_scores_per_beam.shape[-1]))

            possible_beam_predicted_ids = beam_predicted_ids.unsqueeze(
                2).expand(-1, -1, top_scores_per_beam.shape[-1], -1)
            pool_next_scores = possible_next_scores.reshape(batch_size, -1)
            pool_next_status = possible_next_status.reshape(batch_size, -1)
            pool_next_ids = top_indices_per_beam.reshape(batch_size, -1)
            pool_predicted_ids = possible_beam_predicted_ids.reshape(
                batch_size, -1, beam_predicted_ids.shape[-1])

            possible_decoder_state = flat_decoder_state.reshape(
                batch_size, beam_size, flat_decoder_state.shape[-1])
            possible_decoder_state = possible_decoder_state.unsqueeze(
                2).expand(-1, -1, top_scores_per_beam.shape[-1], -1)
            pool_decoder_state = possible_decoder_state.reshape(
                batch_size, -1, possible_decoder_state.shape[-1])

            top_scores, top_indices = torch.topk(pool_next_scores,
                                                 k=BEAM_SIZE,
                                                 dim=-1)
            next_step_ids = torch.gather(pool_next_ids, -1, top_indices)

            decoder_state = torch.gather(
                pool_decoder_state, 1,
                top_indices.unsqueeze(-1).expand(-1, -1,
                                                 pool_decoder_state.shape[-1]))
            decoder_state = decoder_state.reshape(-1, decoder_state.shape[-1])
            beam_status = torch.gather(pool_next_status, -1, top_indices)
            beam_scores = torch.gather(pool_next_scores, -1, top_indices)

            end_tags = torch.full_like(next_step_ids,
                                       self.embedding_store.get_end_id())
            next_step_ids = torch.where(beam_status == 1, end_tags,
                                        next_step_ids)

            beam_predicted_ids = torch.gather(
                pool_predicted_ids, 1,
                top_indices.unsqueeze(-1).expand(-1, -1,
                                                 pool_predicted_ids.shape[-1]))
            beam_predicted_ids[:, :, i] = next_step_ids

            unks = torch.full_like(
                next_step_ids,
                self.embedding_store.get_nl_id(Vocabulary.get_unk()))
            decoder_input = torch.where(
                next_step_ids < len(self.embedding_store.nl_vocabulary),
                next_step_ids, unks).unsqueeze(-1)

        return beam_predicted_ids, beam_scores
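The pruning step above pools all beam continuations and keeps the global top BEAM_SIZE per batch element. A reduced sketch of that pooled topk (the example first takes a per-beam topk before pooling; scores below are fabricated):

import torch

BEAM_SIZE, vocab = 2, 5
# Continuation scores for two live beams of one batch element.
possible_next_scores = torch.tensor([[[0.10, 0.40, 0.05, 0.25, 0.20],
                                      [0.30, 0.15, 0.35, 0.10, 0.10]]])

pool = possible_next_scores.reshape(1, -1)  # flatten beams into one pool
top_scores, top_indices = torch.topk(pool, k=BEAM_SIZE, dim=-1)

parent_beam = top_indices // vocab  # which beam each survivor extends
token_id = top_indices % vocab      # which token extends it
print(top_scores)   # tensor([[0.4000, 0.3500]])
print(parent_beam)  # tensor([[0, 1]])
print(token_id)     # tensor([[1, 2]])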
Example #9
    def beam_decode(self, initial_state, encoder_hidden_states,
                    code_hidden_states, old_nl_hidden_states, masks,
                    max_out_len, batch_data, code_masks, old_nl_masks, device):
        """Beam search. Generates the top K candidate predictions."""
        batch_size = initial_state.shape[0]
        decoded_batch = [list() for _ in range(batch_size)]
        decoded_batch_scores = np.zeros([batch_size, BEAM_SIZE])

        for b_idx in range(batch_size):
            beam_scores = torch.ones(BEAM_SIZE, dtype=torch.float32, device=device)
            beam_status = torch.zeros(BEAM_SIZE, dtype=torch.uint8, device=device)
            beam_predicted_ids = [list() for _ in range(BEAM_SIZE)]
            
            decoder_state = initial_state[b_idx].unsqueeze(0)
            decoder_input = torch.tensor([[self.embedding_store.get_nl_id(START)]], device=device)

            for i in range(max_out_len):
                beam_size = decoder_input.shape[0]
                tiled_encoder_states = encoder_hidden_states[b_idx].unsqueeze(0).expand(beam_size, -1, -1)
                tiled_masks = masks[b_idx].expand(beam_size, -1).unsqueeze(1)
                
                tiled_code_encoder_states = code_hidden_states[b_idx].unsqueeze(0).expand(beam_size, -1, -1)
                tiled_old_nl_encoder_states = old_nl_hidden_states[b_idx].unsqueeze(0).expand(beam_size, -1, -1)
                tiled_code_masks = code_masks[b_idx].expand(beam_size, -1).unsqueeze(1)
                tiled_old_nl_masks = old_nl_masks[b_idx].expand(beam_size, -1).unsqueeze(1)

                decoder_input_embeddings = self.embedding_store.get_nl_embeddings(decoder_input)
                decoder_attention_states, decoder_state, generation_logprobs, copy_logprobs = self.decode(decoder_state, decoder_input_embeddings,
                    tiled_encoder_states, tiled_code_encoder_states, tiled_old_nl_encoder_states, tiled_masks,
                    tiled_code_masks, tiled_old_nl_masks)
                
                generation_logprobs = generation_logprobs.squeeze(1)
                copy_logprobs = copy_logprobs.squeeze(1)

                prob_scores = torch.zeros([beam_size,
                    generation_logprobs.shape[-1] + copy_logprobs.shape[-1]], dtype=torch.float32, device=device)
                prob_scores[:, :generation_logprobs.shape[-1]] = torch.exp(generation_logprobs)
                # Fold copy probabilities into the extended distribution.
                for b in range(beam_size):
                    for c, inp_id in enumerate(batch_data.input_ids[b_idx]):
                        prob_scores[b, inp_id] = prob_scores[b, inp_id] + torch.exp(copy_logprobs[b, c])

                decoder_state = decoder_state.squeeze(0)
                top_scores_per_beam, top_indices_per_beam = torch.topk(prob_scores, k=BEAM_SIZE, dim=-1)
                top_scores_per_beam = top_scores_per_beam.reshape(-1)
                top_indices_per_beam = top_indices_per_beam.reshape(-1)

                full_scores = torch.zeros(beam_size * BEAM_SIZE, dtype=torch.float32, device=device)
                beam_positions = torch.zeros(beam_size * BEAM_SIZE, dtype=torch.int64, device=device)

                for beam_idx in range(beam_size):
                    base = beam_idx * beam_size
                    if beam_status[beam_idx] == 1:
                        # Completed beam: keep a single ray carrying its score.
                        full_scores[base] = beam_scores[beam_idx]
                        for sub_beam_idx in range(BEAM_SIZE):
                            beam_positions[base + sub_beam_idx] = beam_idx
                    else:
                        for sub_beam_idx in range(BEAM_SIZE):
                            idx = base + sub_beam_idx
                            beam_positions[idx] = beam_idx
                            full_scores[idx] = beam_scores[beam_idx] * top_scores_per_beam[idx]
                
                # https://github.com/budzianowski/PyTorch-Beam-Search-Decoding/blob/9f6b66f43d2e05175dabcc024f79e1d37a667070/decode_beam.py#L124
                top_scores, top_indices = torch.topk(full_scores, k=BEAM_SIZE, dim=-1)
                new_scores = torch.ones(BEAM_SIZE, dtype=torch.float32, device=device)
                new_status = torch.zeros(BEAM_SIZE, dtype=torch.uint8, device=device)
                new_ids = [list() for _ in range(BEAM_SIZE)]
                next_step_ids = torch.zeros(BEAM_SIZE, dtype=torch.int64, device=device)
                next_decoder_state = torch.zeros([BEAM_SIZE, decoder_state.shape[1]], dtype=torch.float32, device=device)

                for b, pos in enumerate(top_indices):
                    beam_idx = beam_positions[pos]
                    next_decoder_state[b] = decoder_state[beam_idx]
                    if beam_status[beam_idx] == 1:
                        new_scores[b] = beam_scores[beam_idx]
                        new_status[b] = beam_status[beam_idx]
                        new_ids[b] = beam_predicted_ids[beam_idx]
                        next_step_ids[b] = self.embedding_store.get_end_id()
                    else:
                        new_scores[b] = top_scores[b]
                        predicted_id = top_indices_per_beam[pos]
                        new_status[b] = self.embedding_store.get_end_id() == predicted_id
                        new_ids[b] = beam_predicted_ids[beam_idx] + [predicted_id]
                        next_step_ids[b] = predicted_id
                
                unks = torch.ones(
                    next_step_ids.shape[0], dtype=torch.int64, device=device) * self.embedding_store.get_nl_id(Vocabulary.get_unk())
                decoder_input = torch.where(next_step_ids < len(self.embedding_store.nl_vocabulary), next_step_ids, unks).unsqueeze(1)
                decoder_state = next_decoder_state
                beam_scores = new_scores
                beam_status = new_status
                beam_predicted_ids = new_ids
        
            # Move scores to CPU before storing them in the numpy buffer.
            decoded_batch_scores[b_idx] = beam_scores.detach().cpu().numpy()
            decoded_batch[b_idx] = beam_predicted_ids

        return decoded_batch, decoded_batch_scores
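One detail shared by both decoders: predicted ids that fall outside the fixed vocabulary (i.e. copy ids) have no embedding row, so they are replaced with the UNK id before the next embedding lookup. A toy version of that torch.where substitution (ids invented):

import torch

vocab_size, unk_id = 5, 2
predicted_ids = torch.tensor([1, 6, 4, 9])  # 6 and 9 are copy-only ids
unks = torch.full_like(predicted_ids, unk_id)
decoder_input = torch.where(predicted_ids < vocab_size, predicted_ids, unks)
print(decoder_input)  # tensor([1, 2, 4, 2])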
Example #10
    def is_nl_unk(self, id):
        return id == self.__nl_vocabulary.get_id_or_unk(Vocabulary.get_unk())