示例#1
0
def train(model: CodeSearchNN, data_manager: DataManager, languages: List[str],
          device: torch.device, **kwargs):
    """Train `model` on the given languages, then evaluate MRR on the test set.

    Loads train/valid sequence sets, runs the training loop (with early
    stopping inside ``train_model``), then reloads the best checkpoint from
    ``data_manager`` and reports test MRR overall and per language.

    Args:
        model: the code-search network to train.
        data_manager: provides encoded sequence data and model persistence.
        languages: language identifiers to train/evaluate on.
        device: torch device used for training and evaluation.
        **kwargs: forwarded to ``train_model`` (e.g. ``learning_rate``,
            ``batch_size``, ``mrr_eval_batch_size``, ``verbose``).
    """
    train_language_seqs = load_language_set_seqs(data_manager, languages,
                                                 shared.DataSet.TRAIN)
    valid_language_seqs = load_language_set_seqs(data_manager, languages,
                                                 shared.DataSet.VALID)

    train_model(model, train_language_seqs, valid_language_seqs, data_manager,
                device, **kwargs)

    test_language_code_seqs, test_language_query_seqs = load_language_set_seqs(
        data_manager, languages, shared.DataSet.TEST)

    # Evaluate the best checkpoint persisted by early stopping, not the
    # (possibly over-trained) live model.
    best_model = data_manager.get_torch_model(model)
    # Bug fix: eval() must be called on the model actually being evaluated —
    # the original called model.eval(), leaving best_model potentially in
    # train mode (dropout/batch-norm active) if it is a distinct object.
    best_model.eval()
    with torch.no_grad():
        test_mean_mrr, test_mean_mrr_per_language = evaluate_mrr(
            best_model,
            test_language_code_seqs,
            test_language_query_seqs,
            device,
            batch_size=kwargs.get('mrr_eval_batch_size', 1000))

        # Default verbose=True mirrors train_model's default; the original
        # kwargs['verbose'] raised KeyError when the flag was omitted.
        if kwargs.get('verbose', True):
            print(f'Test MRR: {test_mean_mrr:.4f}')
            # Bug fix: second label previously duplicated 'Test MRR:'.
            print(f'Test MRR per language: {test_mean_mrr_per_language}')
示例#2
0
def get_query_embedding(model: CodeSearchNN, data_manager: DataManager,
                        query: str, max_query_seq_length: int,
                        device: torch.device):
    """Encode a single natural-language query into its embedding vector.

    The query string is padded/encoded via the data manager, pushed through
    the model's query encoder under no_grad, and returned as a 1-D numpy
    array (the first — and only — row of the batch).
    """
    encoded_query = pad_encode_query(data_manager, query, max_query_seq_length)
    with torch.no_grad():
        query_tensor = np_to_torch(encoded_query, device)
        embeddings = model.encode_query(query_tensor)
    return torch_gpu_to_np(embeddings)[0, :]
示例#3
0
def get_ndcg_predictions(
        queries,
        model: CodeSearchNN,
        data_manager: DataManager,
        device: torch.device,
        nn_lib: str = 'scikit',
        n_neighbors: int = 150,
        search_k: int = -1):
    """Retrieve nearest-neighbor code documents for each query, per language.

    For every language, encodes all queries with the model's query encoder
    and looks up the ``n_neighbors`` closest code embeddings using either
    scikit-learn (exact, cosine distance) or annoy (approximate).

    Args:
        queries: iterable of raw query strings.
        model: trained code-search network supplying ``encode_query``.
        data_manager: source of corpora, vocabularies, embeddings and indices.
        device: torch device for query encoding.
        nn_lib: 'scikit' or 'annoy'.
        n_neighbors: neighbors retrieved per query.
        search_k: annoy-only search accuracy/speed trade-off (-1 = default).

    Returns:
        List of prediction dicts with 'query', 'language', 'identifier', 'url'.

    Raises:
        ValueError: if ``nn_lib`` is not a known library name.
    """
    # Robustness fix: queries is iterated once per language plus once for
    # encoding — a generator argument would be silently exhausted after the
    # first pass, yielding empty predictions. Materialize it exactly once.
    queries = list(queries)

    predictions = []
    for language in shared.LANGUAGES:
        print(f'Evaluating {language}')

        evaluation_docs = [{'url': doc['url'], 'identifier': doc['identifier']}
                           for doc in data_manager.get_language_corpus(language, shared.DataSet.ALL)]

        def _append_predictions(query, code_indices):
            # One prediction row per retrieved code document (dedupes the
            # previously copy-pasted dict construction in both branches).
            for code_idx in code_indices:
                doc = evaluation_docs[code_idx]
                predictions.append({
                    'query': query,
                    'language': language,
                    'identifier': doc['identifier'],
                    'url': doc['url'],
                })

        with torch.no_grad():
            query_seqs = prepare_data.pad_encode_seqs(
                (line.split(' ') for line in queries),
                shared.QUERY_MAX_SEQ_LENGTH,
                data_manager.get_query_vocabulary(),
                preprocessing_tokens.preprocess_query_tokens)
            query_embeddings = torch_gpu_to_np(model.encode_query(np_to_torch(query_seqs, device)))

        if nn_lib == 'scikit':
            nn = NearestNeighbors(n_neighbors=n_neighbors, metric='cosine', n_jobs=-1)
            nn.fit(data_manager.get_language_code_embeddings(language))
            _, nearest_neighbor_indices_per_query = nn.kneighbors(query_embeddings)

            for query_idx, query in enumerate(queries):
                _append_predictions(query, nearest_neighbor_indices_per_query[query_idx, :])
        elif nn_lib == 'annoy':
            annoy_index = data_manager.get_language_annoy_index(get_annoy_index(query_embeddings.shape[1]), language)
            for query_idx, query in enumerate(queries):
                nearest_neighbor_indices = annoy_index.get_nns_by_vector(
                    query_embeddings[query_idx, :], n_neighbors, search_k=search_k)
                _append_predictions(query, nearest_neighbor_indices)
        else:
            # ValueError is a subclass of Exception, so existing broad
            # handlers still catch it; message kept identical.
            raise ValueError('Unknown nearest neighbors library.')

        # Free the per-language doc list before loading the next corpus to
        # keep peak memory down.
        del evaluation_docs
        gc.collect()

    return predictions
示例#4
0
def get_code_embedding(model: CodeSearchNN, data_manager: DataManager,
                       code: str, language: str, max_code_seq_length: int,
                       device: torch.device):
    """Encode one code snippet into its embedding vector.

    Tokenizes the raw code blob for the given language, pads/encodes the
    tokens, runs the model's per-language code encoder under no_grad, and
    returns the single resulting embedding as a 1-D numpy array.
    """
    code_tokens = extract_tokens_from_blob(code, language)
    encoded_code = pad_encode_code_tokens(data_manager, code_tokens, language,
                                          max_code_seq_length)
    with torch.no_grad():
        code_tensor = np_to_torch(encoded_code, device)
        embeddings = model.encode_code(language, code_tensor)
        return torch_gpu_to_np(embeddings)[0, :]
示例#5
0
def get_language_mrrs(model: CodeSearchNN,
                      language: str,
                      code_seqs: torch.Tensor,
                      query_seqs: torch.Tensor,
                      batch_size=1000,
                      seed=0) -> List[float]:
    """Compute mean reciprocal rank per shuffled batch for one language.

    Rows of ``code_seqs``/``query_seqs`` are paired; within each batch every
    query is ranked against all code snippets of that batch by cosine
    distance, and the batch MRR is collected. Shuffling is seeded for
    reproducibility.

    Returns:
        One MRR value per full batch (trailing partial batches are skipped
        so every MRR is computed over the same candidate-pool size).
    """
    shuffled_indices = list(range(code_seqs.shape[0]))
    random.Random(seed).shuffle(shuffled_indices)

    batch_mrrs = []
    for batch_indices in chunked(shuffled_indices, batch_size):
        # Only score full-size batches; a smaller pool would inflate MRR.
        if len(batch_indices) < batch_size:
            continue

        code_embs = torch_gpu_to_np(model.encode_code(language, code_seqs[batch_indices]))
        query_embs = torch_gpu_to_np(model.encode_query(query_seqs[batch_indices]))

        # Pairwise cosine distances; the diagonal holds each query's true match.
        distances = cdist(query_embs, code_embs, 'cosine')
        correct_distances = np.expand_dims(np.diag(distances), axis=-1)
        ranks = np.sum(distances <= correct_distances, axis=-1)
        # Keep only valid ranks (finite and at least 1).
        valid_ranks = ranks[np.invert(np.isnan(ranks)) & (ranks >= 1)]
        batch_mrrs.append(float(np.mean(1.0 / valid_ranks)))

    return batch_mrrs
示例#6
0
def batch_encode_code_seqs(model: CodeSearchNN,
                           language: str,
                           code_seqs: np.ndarray,
                           device: torch.device,
                           batch_size=1000):
    """Encode code sequences through the model in fixed-size batches.

    Args:
        model: trained network supplying ``encode_code`` and ``embedding_size``.
        language: language key passed through to the encoder.
        code_seqs: 2-D array of padded/encoded code token sequences.
        device: torch device the batches are moved to.
        batch_size: rows encoded per forward pass.

    Returns:
        (n_seqs, embedding_size) numpy array of code embeddings.
    """
    n_seqs = code_seqs.shape[0]
    code_embeddings = np.zeros((n_seqs, model.embedding_size))

    # Bug fix: the original iterated (n_seqs // batch_size) + 1 times, which
    # issued a trailing EMPTY batch whenever n_seqs was an exact multiple of
    # batch_size (and one pointless model call for n_seqs == 0). Stepping the
    # start index directly covers every row exactly once.
    for start in range(0, n_seqs, batch_size):
        end = min(n_seqs, start + batch_size)
        with torch.no_grad():
            batch_code_seqs = np_to_torch(code_seqs[start:end, :], device)
            code_embeddings[start:end, :] = torch_gpu_to_np(
                model.encode_code(language, batch_code_seqs))

    return code_embeddings
示例#7
0
def train_model(model: CodeSearchNN,
                train_language_seqs: Tuple[Dict[str, np.ndarray],
                                           Dict[str, np.ndarray]],
                valid_language_seqs: Tuple[Dict[str, np.ndarray],
                                           Dict[str, np.ndarray]],
                data_manager: DataManager,
                device: torch.device,
                learning_rate=1e-3,
                batch_size=1000,
                max_epochs=100,
                patience=5,
                mrr_eval_batch_size=1000,
                verbose=True):
    """Run the training loop with Adam and MRR-based early stopping.

    Each of ``train_language_seqs``/``valid_language_seqs`` is a
    (code_seqs_per_language, query_seqs_per_language) pair of dicts keyed by
    language. Per epoch: train on batches produced by ``generate_batch``
    using a cosine loss, then evaluate validation MRR; ``EarlyStopping``
    (which receives ``data_manager``, presumably to checkpoint the best
    model — confirm against its implementation) signals when to stop.

    Args:
        model: network to optimize in place.
        train_language_seqs: training (code, query) sequence dicts.
        valid_language_seqs: validation (code, query) sequence dicts.
        data_manager: handed to EarlyStopping (model persistence).
        device: torch device batches are moved to.
        learning_rate: Adam learning rate.
        batch_size: training batch size.
        max_epochs: hard cap on epochs if early stopping never triggers.
        patience: epochs without validation-MRR improvement before stopping.
        mrr_eval_batch_size: batch size for the MRR evaluation.
        verbose: print per-epoch timing, loss, and MRR when True.
    """
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    es = EarlyStopping(model, data_manager, patience=patience, verbose=verbose)

    # Unpack the per-language code/query sequence dicts.
    train_language_code_seqs, train_language_query_seqs = train_language_seqs
    valid_language_code_seqs, valid_language_query_seqs = valid_language_seqs

    for epoch in range(max_epochs):
        if verbose:
            print(f'=== Epoch {epoch} ===')
        epoch_start = time.time()

        model.train()

        loss_per_batch = []
        for batch_language_code_seqs, batch_language_query_seqs in generate_batch(
                train_language_code_seqs, train_language_query_seqs,
                batch_size):
            # Move each language's arrays to the device in place. A language
            # may have no data in this batch (None) and is skipped; mutating
            # values while iterating keys is safe since no keys are added.
            for language in batch_language_code_seqs.keys():
                if batch_language_code_seqs[language] is None:
                    continue

                batch_language_code_seqs[language] = np_to_torch(
                    batch_language_code_seqs[language], device)
                batch_language_query_seqs[language] = np_to_torch(
                    batch_language_query_seqs[language], device)

            # Standard step: zero grads, forward, cosine loss, backward, update.
            optimizer.zero_grad()
            output = model(batch_language_code_seqs, batch_language_query_seqs)
            loss = cosine_loss(output, device)
            loss.backward()
            optimizer.step()

            loss_per_batch.append(loss.item())

        # Validation: switch to eval mode and disable autograd for MRR.
        model.eval()
        with torch.no_grad():
            validation_mean_mrr, validation_mean_mrr_per_language = evaluate_mrr(
                model,
                valid_language_code_seqs,
                valid_language_query_seqs,
                device,
                batch_size=mrr_eval_batch_size)

        if verbose:
            mean_loss_per_batch = np.mean(loss_per_batch)
            epoch_duration = time.time() - epoch_start
            print(f'Duration: {epoch_duration:.1f}s')
            print(
                f'Train loss: {mean_loss_per_batch:.4f}, Valid MRR: {validation_mean_mrr:.4f}'
            )
            print(
                f'Valid MRR per language: {validation_mean_mrr_per_language}')

        # EarlyStopping returns True once patience is exhausted.
        if es(validation_mean_mrr):
            break