def train(model: CodeSearchNN, data_manager: DataManager, languages: List[str], device: torch.device, **kwargs):
    """Train `model` on the given languages, then report test MRR for the best saved checkpoint."""
    train_language_seqs = load_language_set_seqs(data_manager, languages, shared.DataSet.TRAIN)
    valid_language_seqs = load_language_set_seqs(data_manager, languages, shared.DataSet.VALID)
    train_model(model, train_language_seqs, valid_language_seqs, data_manager, device, **kwargs)

    test_language_code_seqs, test_language_query_seqs = load_language_set_seqs(
        data_manager, languages, shared.DataSet.TEST)
    best_model = data_manager.get_torch_model(model)
    best_model.eval()
    with torch.no_grad():
        test_mean_mrr, test_mean_mrr_per_language = evaluate_mrr(
            best_model,
            test_language_code_seqs,
            test_language_query_seqs,
            device,
            batch_size=kwargs.get('mrr_eval_batch_size', 1000))

    if kwargs.get('verbose', True):
        print(f'Test MRR: {test_mean_mrr:.4f}')
        print(f'Test MRR per language: {test_mean_mrr_per_language}')

def get_query_embedding(model: CodeSearchNN, data_manager: DataManager, query: str,
                        max_query_seq_length: int, device: torch.device):
    with torch.no_grad():
        padded_encoded_query = np_to_torch(
            pad_encode_query(data_manager, query, max_query_seq_length), device)
        return torch_gpu_to_np(model.encode_query(padded_encoded_query))[0, :]

def get_ndcg_predictions(
        queries: List[str],
        model: CodeSearchNN,
        data_manager: DataManager,
        device: torch.device,
        nn_lib: str = 'scikit',
        n_neighbors: int = 150,
        search_k: int = -1):
    """Collect the nearest code documents for every query in every language, for NDCG evaluation."""
    predictions = []
    for language in shared.LANGUAGES:
        print(f'Evaluating {language}')
        evaluation_docs = [{'url': doc['url'], 'identifier': doc['identifier']}
                           for doc in data_manager.get_language_corpus(language, shared.DataSet.ALL)]

        with torch.no_grad():
            query_seqs = prepare_data.pad_encode_seqs(
                (line.split(' ') for line in queries),
                shared.QUERY_MAX_SEQ_LENGTH,
                data_manager.get_query_vocabulary(),
                preprocessing_tokens.preprocess_query_tokens)
            query_embeddings = torch_gpu_to_np(model.encode_query(np_to_torch(query_seqs, device)))

        if nn_lib == 'scikit':
            nn = NearestNeighbors(n_neighbors=n_neighbors, metric='cosine', n_jobs=-1)
            nn.fit(data_manager.get_language_code_embeddings(language))
            _, nearest_neighbor_indices_per_query = nn.kneighbors(query_embeddings)

            for query_idx, query in enumerate(queries):
                for query_nearest_code_idx in nearest_neighbor_indices_per_query[query_idx, :]:
                    predictions.append({
                        'query': query,
                        'language': language,
                        'identifier': evaluation_docs[query_nearest_code_idx]['identifier'],
                        'url': evaluation_docs[query_nearest_code_idx]['url'],
                    })
        elif nn_lib == 'annoy':
            annoy_index = data_manager.get_language_annoy_index(
                get_annoy_index(query_embeddings.shape[1]), language)

            for query_idx, query in enumerate(queries):
                nearest_neighbor_indices = annoy_index.get_nns_by_vector(
                    query_embeddings[query_idx, :], n_neighbors, search_k=search_k)
                for query_nearest_code_idx in nearest_neighbor_indices:
                    predictions.append({
                        'query': query,
                        'language': language,
                        'identifier': evaluation_docs[query_nearest_code_idx]['identifier'],
                        'url': evaluation_docs[query_nearest_code_idx]['url'],
                    })
        else:
            raise ValueError(f'Unknown nearest neighbors library: {nn_lib}')

        del evaluation_docs
        gc.collect()

    return predictions

def get_code_embedding(model: CodeSearchNN, data_manager: DataManager, code: str, language: str,
                       max_code_seq_length: int, device: torch.device):
    tokens = extract_tokens_from_blob(code, language)
    with torch.no_grad():
        padded_encoded_code = np_to_torch(
            pad_encode_code_tokens(data_manager, tokens, language, max_code_seq_length), device)
        return torch_gpu_to_np(model.encode_code(language, padded_encoded_code))[0, :]

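# A minimal sketch (not part of the original pipeline) of how the two embedding helpers above can be
# combined: embed one query and one code snippet and score them with the same cosine distance used
# during evaluation. The function name and its parameters are illustrative; callers would pass the
# project's usual max-sequence-length constants.
def example_query_code_similarity(model: CodeSearchNN, data_manager: DataManager, query: str, code: str,
                                  language: str, max_query_seq_length: int, max_code_seq_length: int,
                                  device: torch.device) -> float:
    query_vec = get_query_embedding(model, data_manager, query, max_query_seq_length, device)
    code_vec = get_code_embedding(model, data_manager, code, language, max_code_seq_length, device)
    # cdist expects 2D inputs, so lift both vectors to shape (1, embedding_size).
    cosine_distance = cdist(query_vec[np.newaxis, :], code_vec[np.newaxis, :], 'cosine')[0, 0]
    return 1.0 - float(cosine_distance)
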
def get_language_mrrs(model: CodeSearchNN, language: str, code_seqs: torch.Tensor,
                      query_seqs: torch.Tensor, batch_size=1000, seed=0) -> List[float]:
    """Compute MRR for each full batch of (code, query) pairs in one language."""
    indices = list(range(code_seqs.shape[0]))
    random.Random(seed).shuffle(indices)

    mrrs = []
    for idx_chunk in chunked(indices, batch_size):
        # Skip the final, incomplete batch so every MRR is computed over the same number of candidates.
        if len(idx_chunk) < batch_size:
            continue

        code_embeddings = torch_gpu_to_np(model.encode_code(language, code_seqs[idx_chunk]))
        query_embeddings = torch_gpu_to_np(model.encode_query(query_seqs[idx_chunk]))

        distance_matrix = cdist(query_embeddings, code_embeddings, 'cosine')
        # The i-th query belongs to the i-th code snippet, so the diagonal holds the correct pairs;
        # counting the distances in each row that are <= the diagonal entry gives the 1-based rank.
        correct_elements = np.expand_dims(np.diag(distance_matrix), axis=-1)
        ranks = np.sum(distance_matrix <= correct_elements, axis=-1)
        ranks = ranks[np.invert(np.isnan(ranks)) & (ranks >= 1)]  # Make sure we only use valid ranks
        mrrs.append(float(np.mean(1.0 / ranks)))
    return mrrs

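# A minimal sketch of the ranking trick used in get_language_mrrs, applied to a hand-built 3x3
# cosine-distance matrix. This is illustration only; nothing in the pipeline calls it.
def _mrr_ranking_example() -> float:
    distance_matrix = np.array([
        [0.1, 0.5, 0.4],  # correct snippet is the closest -> rank 1
        [0.3, 0.2, 0.6],  # correct snippet is the closest -> rank 1
        [0.2, 0.1, 0.7],  # two other snippets are closer  -> rank 3
    ])
    # Row i is a query; the diagonal holds the distance to its correct code snippet.
    correct_elements = np.expand_dims(np.diag(distance_matrix), axis=-1)
    ranks = np.sum(distance_matrix <= correct_elements, axis=-1)  # -> [1, 1, 3]
    return float(np.mean(1.0 / ranks))  # -> (1 + 1 + 1/3) / 3 ≈ 0.78
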
def batch_encode_code_seqs(model: CodeSearchNN, language: str, code_seqs: np.ndarray,
                           device: torch.device, batch_size=1000):
    n_seqs = code_seqs.shape[0]
    code_embeddings = np.zeros((n_seqs, model.embedding_size))

    idx = 0
    for _ in range((n_seqs // batch_size) + 1):
        end_idx = min(n_seqs, idx + batch_size)
        with torch.no_grad():
            batch_code_seqs = np_to_torch(code_seqs[idx:end_idx, :], device)
            code_embeddings[idx:end_idx, :] = torch_gpu_to_np(
                model.encode_code(language, batch_code_seqs))
        idx += batch_size

    return code_embeddings

def train_model(model: CodeSearchNN,
                train_language_seqs: Tuple[Dict[str, np.ndarray], Dict[str, np.ndarray]],
                valid_language_seqs: Tuple[Dict[str, np.ndarray], Dict[str, np.ndarray]],
                data_manager: DataManager,
                device: torch.device,
                learning_rate=1e-3,
                batch_size=1000,
                max_epochs=100,
                patience=5,
                mrr_eval_batch_size=1000,
                verbose=True):
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    es = EarlyStopping(model, data_manager, patience=patience, verbose=verbose)

    train_language_code_seqs, train_language_query_seqs = train_language_seqs
    valid_language_code_seqs, valid_language_query_seqs = valid_language_seqs

    for epoch in range(max_epochs):
        if verbose:
            print(f'=== Epoch {epoch} ===')
        epoch_start = time.time()

        model.train()
        loss_per_batch = []
        for batch_language_code_seqs, batch_language_query_seqs in generate_batch(
                train_language_code_seqs, train_language_query_seqs, batch_size):
            for language in batch_language_code_seqs.keys():
                if batch_language_code_seqs[language] is None:
                    continue
                batch_language_code_seqs[language] = np_to_torch(
                    batch_language_code_seqs[language], device)
                batch_language_query_seqs[language] = np_to_torch(
                    batch_language_query_seqs[language], device)

            optimizer.zero_grad()
            output = model(batch_language_code_seqs, batch_language_query_seqs)
            loss = cosine_loss(output, device)
            loss.backward()
            optimizer.step()
            loss_per_batch.append(loss.item())

        model.eval()
        with torch.no_grad():
            validation_mean_mrr, validation_mean_mrr_per_language = evaluate_mrr(
                model, valid_language_code_seqs, valid_language_query_seqs, device,
                batch_size=mrr_eval_batch_size)

        if verbose:
            mean_loss_per_batch = np.mean(loss_per_batch)
            epoch_duration = time.time() - epoch_start
            print(f'Duration: {epoch_duration:.1f}s')
            print(f'Train loss: {mean_loss_per_batch:.4f}, Valid MRR: {validation_mean_mrr:.4f}')
            print(f'Valid MRR per language: {validation_mean_mrr_per_language}')

        if es(validation_mean_mrr):
            break
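

# A minimal end-to-end usage sketch, under two assumptions: a CodeSearchNN and DataManager have
# already been constructed elsewhere (their constructors are not shown in this module), and
# CodeSearchNN is a torch.nn.Module so `.to(device)` is available. The keyword arguments simply
# restate train_model's defaults.
def example_training_run(model: CodeSearchNN, data_manager: DataManager) -> None:
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    train(model, data_manager, list(shared.LANGUAGES), device,
          learning_rate=1e-3, batch_size=1000, max_epochs=100,
          patience=5, mrr_eval_batch_size=1000, verbose=True)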