def main():
    """Search the base-language corpus for code matching a natural-language query.

    Usage: <script> <query>
    Embeds the query, finds the nearest code embeddings per language, and
    prints the identifier, URL, and source of each hit.
    """
    # Fail fast with a usage message instead of a bare IndexError when the
    # query argument is missing.
    if len(sys.argv) < 2:
        print(f'Usage: {sys.argv[0]} <query>', file=sys.stderr)
        sys.exit(1)
    query = sys.argv[1]

    data_manager = get_base_languages_data_manager()
    device = get_device()
    model = get_base_language_model_for_evaluation(data_manager, device)
    query_embedding = get_query_embedding(model, data_manager, query,
                                          shared.QUERY_MAX_SEQ_LENGTH, device)
    nearest_neighbors_per_language = get_nearest_embedding_neighbors_per_language(
        data_manager, shared.LANGUAGES, query_embedding, results_per_language=30)

    for language, nearest_neighbors in nearest_neighbors_per_language.items():
        print(language)
        # The corpus is loaded in the same order the embedding index was
        # built, so neighbor indices address this list positionally.
        evaluation_docs = [{
            'url': doc['url'],
            'identifier': doc['identifier'],
            'code': doc['code']
        } for doc in data_manager.get_language_corpus(language,
                                                      shared.DataSet.ALL)]
        for idx in nearest_neighbors[0]:
            print(evaluation_docs[idx]['identifier'], evaluation_docs[idx]['url'])
            print(evaluation_docs[idx]['code'])
            print('=' * 20)
def api_repository_search_by_code_view(request, repository_organization,
                                       repository_name):
    """Django view: find repository code similar to a POSTed code snippet.

    Expects a JSON body with 'code' (the snippet, at most 4096 chars) and
    'language' (must be one of the repository's languages). Returns a
    JsonResponse with the nearest code documents sorted by embedding
    distance, or HttpResponseBadRequest on invalid input.
    """
    if request.method != 'POST':
        return HttpResponseBadRequest('Invalid HTTP method.')
    # A malformed body must produce a 400, not an unhandled 500.
    try:
        body = json.loads(request.body)
    except (json.JSONDecodeError, UnicodeDecodeError):
        return HttpResponseBadRequest('Invalid JSON body.')
    code = body.get('code')
    if not code or len(code.strip()) == 0:
        return HttpResponseBadRequest('Invalid or missing code.')
    if len(code) > 4096:
        return HttpResponseBadRequest('Code too long.')
    language = body.get('language')

    repository = get_object_or_404(models.CodeRepository,
                                   organization=repository_organization,
                                   name=repository_name)
    repository_languages = [
        language.name for language in repository.languages.all()
    ]
    # Also rejects a missing 'language' key (None is never in the list).
    if language not in repository_languages:
        return HttpResponseBadRequest(
            f'{language} is not a valid repository language.')

    data_manager = get_repository_data_manager(repository.organization,
                                               repository.name)
    device = get_device()
    model = get_repository_model_for_evaluation(data_manager,
                                                repository_languages, device)
    code_embedding = get_code_embedding(model, data_manager, code, language,
                                        shared.CODE_MAX_SEQ_LENGTH, device)
    indices, distances = get_nearest_embedding_neighbors_per_language(
        data_manager, [language], code_embedding,
        results_per_language=RESULTS_PER_LANGUAGE)[language]
    code_documents = get_code_documents_from_indices(repository, language,
                                                     indices)
    code_documents_with_distances = sort_code_documents_by_distance(
        sort_code_documents_with_distances_by_index(code_documents, indices,
                                                    distances))
    return JsonResponse({
        'codeDocuments':
            code_documents_with_distances_as_json(code_documents_with_distances)
    })
def main():
    """Entry point: build code embeddings and/or Annoy indices for the base
    languages, controlled by the boolean CLI flags."""
    arg_parser = argparse.ArgumentParser(
        description='Build code embeddings and AnnoyIndex.')
    utils.add_bool_arg(arg_parser, 'code-embeddings', default=True)
    utils.add_bool_arg(arg_parser, 'annoy-index', default=False)
    options = vars(arg_parser.parse_args())

    data_manager = get_base_languages_data_manager()
    device = get_device()
    model = get_base_language_model_for_evaluation(data_manager, device)

    if options['code-embeddings']:
        build_code_embeddings(model, data_manager, shared.LANGUAGES, device)
    if options['annoy-index']:
        build_annoy_indices(data_manager, shared.LANGUAGES)
def prepare_repository(repository_data_manager: DataManager,
                       base_data_manager: DataManager,
                       languages: List[str]):
    """Run the full per-repository pipeline: prepare data, build and train
    the model, then materialize code embeddings and Annoy indices."""
    # Data preparation must complete before any model work below.
    RepositoryDataPreparer(
        repository_data_manager, base_data_manager, languages, verbose=True
    ).prepare(
        code_vocabulary_size=shared.CODE_VOCABULARY_SIZE,
        code_pct_bpe=shared.VOCABULARY_PCT_BPE,
        query_vocabulary_size=shared.QUERY_VOCABULARY_SIZE,
        query_pct_bpe=shared.VOCABULARY_PCT_BPE,
        code_seq_max_length=shared.CODE_MAX_SEQ_LENGTH,
        query_seq_max_length=shared.QUERY_MAX_SEQ_LENGTH)

    device = torch_utils.get_device()
    build_repository_model(repository_data_manager, base_data_manager,
                           languages)
    train_repository_model(repository_data_manager, languages, device)

    # Reload the trained model in evaluation mode to produce the search
    # artifacts.
    evaluation_model = get_repository_model_for_evaluation(
        repository_data_manager, languages, device)
    build_code_embeddings(evaluation_model, repository_data_manager,
                          languages, device)
    build_annoy_indices(repository_data_manager, languages, n_trees=600)
def main():
    """Evaluate the base code-search model on the NDCG query set and write
    the predictions to CSV (into the wandb run dir when --wandb is set)."""
    arg_parser = argparse.ArgumentParser(
        description='Evaluate code search model.')
    utils.add_bool_arg(arg_parser, 'wandb', default=False)
    options = vars(arg_parser.parse_args())
    use_wandb = options['wandb']

    if use_wandb:
        wandb.init(project=shared.ENV['WANDB_PROJECT_NAME'],
                   config=shared.get_wandb_config())

    device = get_device()
    data_manager = get_base_languages_data_manager()
    model = get_base_language_model_for_evaluation(data_manager, device)

    predictions = get_ndcg_predictions(get_evaluation_queries(), model,
                                       data_manager, device)
    predictions_frame = pd.DataFrame(
        predictions, columns=['query', 'language', 'identifier', 'url'])

    if use_wandb:
        output_path = os.path.join(wandb.run.dir, 'model_predictions.csv')
    else:
        output_path = '../model_predictions.csv'
    predictions_frame.to_csv(output_path, index=False)
def main():
    """Train the base code-search model from prepared data, optionally
    logging the run to wandb."""
    arg_parser = argparse.ArgumentParser(
        description='Train code search model from prepared data.')
    arg_parser.add_argument('--notes', default='')
    utils.add_bool_arg(arg_parser, 'wandb', default=False)
    options = vars(arg_parser.parse_args())

    data_manager = get_base_languages_data_manager()
    device = get_device()
    model = get_base_language_model(device)

    if options['wandb']:
        wandb.init(project=shared.ENV['WANDB_PROJECT_NAME'],
                   notes=options['notes'],
                   config=shared.get_wandb_config())
        wandb.watch(model)

    train(model,
          data_manager,
          shared.LANGUAGES,
          device,
          learning_rate=shared.LEARNING_RATE,
          batch_size=shared.TRAIN_BATCH_SIZE,
          verbose=True)
def torch_load(path: str):
    """Deserialize a torch object from `path`, remapping any stored tensors
    onto the locally available device via `torch_utils.get_device()` (so
    GPU-saved checkpoints load on CPU-only machines).

    NOTE(review): `torch.load` unpickles arbitrary objects — only call this
    on trusted checkpoint files.
    """
    return torch.load(path, map_location=torch_utils.get_device())
def api_repository_search_view(request, repository_organization,
                               repository_name):
    """Django view: search a repository's code by natural-language query.

    Reads GET parameter 'query' (at most 256 chars). The query may contain
    a '+language:a,b' filter (FILTER_BY_LANGUAGE_REGEX) restricting results
    to those of the repository's languages. Neighbor lookups are cached
    per (organization, repository, query) without expiry. Returns a
    JsonResponse of code documents sorted by embedding distance, or
    HttpResponseBadRequest on invalid input.
    """
    if request.method != 'GET':
        return HttpResponseBadRequest('Invalid HTTP method.')
    query = request.GET.get('query')
    if not query or len(query.strip()) == 0:
        return HttpResponseBadRequest('Invalid or missing query.')
    if len(query) > 256:
        return HttpResponseBadRequest('Query too long.')
    models.QueryLog.objects.create(query=query)

    repository = get_object_or_404(models.CodeRepository,
                                   organization=repository_organization,
                                   name=repository_name)
    repository_languages = [
        language.name for language in repository.languages.all()
    ]
    language_filter_match = FILTER_BY_LANGUAGE_REGEX.search(query)
    if language_filter_match is not None:
        languages_match = language_filter_match.group(1).split(',')
        # Keep only filter languages the repository actually has.
        languages = [
            language.lower() for language in languages_match
            if language.lower() in repository_languages
        ]
        if len(languages) == 0:
            return HttpResponseBadRequest(
                'No valid languages present in the +language filter.')
    else:
        languages = repository_languages

    # SHA-1 here is a cache key, not a security boundary. The raw query
    # (including any +language filter) is part of the key, so different
    # filters cache separately.
    cache_key = hashlib.sha1(
        f'query:{repository_organization}:{repository_name}:{query}'.encode(
            'utf-8')).hexdigest()
    # Single cache round-trip instead of `in cache` followed by
    # `cache.get` — avoids a second lookup and a check-then-act race with
    # eviction. Cached values are dicts, never None, so None means a miss.
    nearest_neighbors_per_language = cache.get(cache_key)
    if nearest_neighbors_per_language is None:
        data_manager = get_repository_data_manager(repository.organization,
                                                   repository.name)
        device = get_device()
        model = get_repository_model_for_evaluation(data_manager, languages,
                                                    device)
        query_embedding = get_query_embedding(model, data_manager,
                                              get_filterless_query(query),
                                              shared.QUERY_MAX_SEQ_LENGTH,
                                              device)
        nearest_neighbors_per_language = get_nearest_embedding_neighbors_per_language(
            data_manager, languages, query_embedding,
            results_per_language=RESULTS_PER_LANGUAGE)
        cache.set(cache_key, nearest_neighbors_per_language,
                  timeout=None)  # Never expire

    code_documents_with_distances = []
    for language in languages:
        indices, distances = nearest_neighbors_per_language[language]
        code_documents = get_code_documents_from_indices(
            repository, language, indices)
        code_documents_with_distances.extend(
            sort_code_documents_with_distances_by_index(
                code_documents, indices, distances))
    code_documents_with_distances = sort_code_documents_by_distance(
        code_documents_with_distances)
    return JsonResponse({
        'codeDocuments':
            code_documents_with_distances_as_json(code_documents_with_distances)
    })