Example #1
        #         indices.add_item(index, vector)
        # indices.build(200)
        #
        # for query in queries:
        #     for idx, _ in zip(*query_model(query, model, indices, language)):
        #         predictions.append((query, language, definitions[idx]['identifier'], definitions[idx]['url']))
        # JGD: only predict over Python for now, so stop after the first language
        break

    df = pd.DataFrame(predictions,
                      columns=['query', 'language', 'identifier', 'url'])
    df.to_csv(predictions_csv, index=False)

    if run_id:
        print('Uploading predictions to W&B')
        # upload model predictions CSV file to W&B

        # we checked that there are three path components above
        entity, project, name = args_wandb_run_id.split('/')

        # make sure the file is in our cwd, with the correct name
        predictions_base_csv = "model_predictions.csv"
        shutil.copyfile(predictions_csv, predictions_base_csv)

        # Using internal wandb API. TODO: Update when available as a public API
        internal_api = InternalApi()
        internal_api.push([predictions_base_csv],
                          run=name,
                          entity=entity,
                          project=project)
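
The snippet above elides the actual retrieval step behind a commented-out query_model(query, model, indices, language) call. Below is a minimal sketch of what such a helper could look like; the encode_query method on model is a hypothetical assumption, while get_nns_by_vector(..., include_distances=True) is Annoy's real API and returns ids and distances as two parallel lists, which is what the zip(*...) in the commented-out loop expects.

# Hypothetical sketch of the elided `query_model` helper from Example #1.
# `model.encode_query` is an assumed API for embedding a single query string;
# `indices` is the Annoy index built above.
from typing import List, Tuple

from annoy import AnnoyIndex


def query_model(query: str,
                model,
                indices: AnnoyIndex,
                language: str,
                topk: int = 100) -> Tuple[List[int], List[float]]:
    # Embed the natural-language query (assumed model API).
    query_embedding = model.encode_query(query, language)
    # Approximate nearest-neighbour search over the code embeddings.
    ids, distances = indices.get_nns_by_vector(query_embedding,
                                               topk,
                                               include_distances=True)
    return ids, distances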
Example #2
def run(args, tag_in_vcs=False) -> None:
    args_wandb_run_id = args["--wandb_run_id"]
    if args_wandb_run_id is not None:
        entity, project, name = args_wandb_run_id.split("/")
        os.environ["WANDB_RUN_ID"] = name
        os.environ["WANDB_RESUME"] = "must"

        wandb_api = wandb.Api()
        # retrieve saved model from W&B for this run
        logger.info("Fetching run from W&B...")
        try:
            wandb_api.run(args_wandb_run_id)
        except wandb.CommError:
            logger.error(
                f"Problem querying W&B for wandb_run_id: {args_wandb_run_id}")
            sys.exit(1)

    else:
        os.environ["WANDB_MODE"] = "dryrun"

    logger.debug("Building Training Context")
    training_ctx: CodeSearchTrainingContext
    restore_dir = args["--restore"]
    logger.info(f"Restoring Training Context from directory{restore_dir}")
    training_ctx = CodeSearchTrainingContext.build_context_from_dir(
        restore_dir)

    queries = pd.read_csv(training_ctx.queries_file)
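    # prefix each query with the "<qy>" marker token before tokenization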
    queries = list(map(lambda q: f"<qy> {q}", queries["query"].values))
    queries_tokens, queries_masks = training_ctx.tokenize_query_sentences(
        queries,
        max_length=training_ctx.conf["dataset.common_params.query_max_num_tokens"])
    logger.info(f"queries: {queries}")

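    # inference only: switch to evaluation mode and disable gradient tracking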
    training_ctx.eval_mode()
    with torch.no_grad():
        query_embeddings = (training_ctx.encode_query(
            query_tokens=torch.tensor(queries_tokens, dtype=torch.long).to(
                training_ctx.device),
            query_tokens_mask=torch.tensor(queries_masks, dtype=torch.long).to(
                training_ctx.device),
        ).cpu().numpy())
        logger.info(f"query_embeddings: {query_embeddings.shape}")

        topk = 100
        language_token = "<lg>"
        for lang_idx, language in enumerate(
            ("python", "go", "javascript", "java", "php", "ruby")):
            predictions = []
            # (codes_encoded_df, codes_masks_df, definitions) = get_language_defs(language, training_ctx, language_token)

            code_embeddings, definitions = compute_code_encodings_from_defs(
                language, training_ctx, language_token, batch_length=512)
            logger.info(
                f"Building Annoy Index of length {len(code_embeddings.values[0])}"
            )
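            # "angular" is Annoy's cosine-style distance between embedding vectors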
            indices: AnnoyIndex = AnnoyIndex(len(code_embeddings.values[0]),
                                             "angular")
            # add one item per encoded code definition to the index
            for index, emb in enumerate(tqdm(code_embeddings.values)):
                indices.add_item(index, emb)
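            # 10 trees: more trees improve accuracy at the cost of build time and index size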
            indices.build(10)

            for i, (query, query_embedding) in enumerate(
                    tqdm(zip(queries, query_embeddings))):
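                # retrieve the topk approximate nearest code vectors and their distances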
                idxs, distances = indices.get_nns_by_vector(
                    query_embedding, topk, include_distances=True)
                for idx2, _ in zip(idxs, distances):
                    predictions.append(
                        (query, language, definitions.iloc[idx2]["identifier"],
                         definitions.iloc[idx2]["url"]))

            logger.info(f"predictions {predictions[0]}")

            df = pd.DataFrame(
                predictions,
                columns=["query", "language", "identifier", "url"])
            # clean up predictions before writing: drop the "<qy> " prefix from queries
            # and strip commas, quotes and semicolons from identifiers
            df["query"] = df["query"].str.replace("<qy> ", "")
            df["identifier"] = df["identifier"].str.replace(",", "")
            df["identifier"] = df["identifier"].str.replace('"', "")
            df["identifier"] = df["identifier"].str.replace(";", "")
            df.to_csv(
                training_ctx.output_dir /
                f"model_predictions_{training_ctx.training_tokenizer_type}.csv",
                index=False,
                header=(lang_idx == 0),
                # mode="w" if lang_idx == 0 else "a",
                mode="a",
            )
            # Free memory
            del code_embeddings
            del definitions
            del predictions

    if args_wandb_run_id is not None:
        logger.info("Uploading predictions to W&B")
        # upload model predictions CSV file to W&B

        entity, project, name = args_wandb_run_id.split("/")

        # make sure the file is in our cwd, with the correct name
        predictions_csv = training_ctx.output_dir / f"model_predictions_{training_ctx.training_tokenizer_type}.csv"
        predictions_base_csv = "model_predictions.csv"
        shutil.copyfile(predictions_csv, predictions_base_csv)

        # Using internal wandb API. TODO: Update when available as a public API
        internal_api = InternalApi()
        internal_api.push([predictions_base_csv],
                          run=name,
                          entity=entity,
                          project=project)
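
As the TODO notes, InternalApi().push is not part of wandb's public API. A rough alternative sketch using only public calls is shown below; it assumes the run identified by entity/project/name can simply be resumed, which the snippet does not confirm.

# Sketch of the same upload via wandb's public API, assuming the run
# can be resumed by id (entity/project/name parsed as above).
import wandb

run = wandb.init(entity=entity,
                 project=project,
                 id=name,
                 resume="must")
# Copies the file into the run's files directory and syncs it to W&B.
wandb.save(predictions_base_csv)
run.finish()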