Example #1
0
def verify_same_hashes(index: ExamplesStore, yids, yhashes):
    """Check that each y set in ``index`` still has its expected hash.

    Args:
        index: The example store to look the y sets up in.
        yids: Ids of the y sets to check.
        yhashes: The expected hash for each id in ``yids``, in the same order.
            Ids and hashes are paired with ``zip``, so entries beyond the
            shorter of the two sequences are not checked.

    Raises:
        ValueError: If the stored hash of a y set differs from the expected
            one. Before raising, the y values and the stored hash of that y
            set are printed to help debugging.
    """
    for yid, expected_hash in zip(yids, yhashes):
        # Query the index once and reuse the value for the debug print.
        actual_hash = index.get_y_set_hash(yid)
        if actual_hash == expected_hash:
            continue
        for yv in index.get_y_values_for_y_set(yid):
            print(yv)
        print(actual_hash)
        raise ValueError(
            f"Not matching hash for y set {yid!r}: "
            f"expected {expected_hash!r}, got {actual_hash!r}")
Example #2
0
def full_ret_from_example_store(example_store: ExamplesStore,
                                replacers: Replacer,
                                pretrained_checkpoint: str,
                                replace_samples: int = REPLACEMENT_SAMPLES,
                                encoder_name="CM") -> FullRetModel:
    """Build a FullRetModel by embedding every x value in an example store.

    Args:
        example_store: Store whose x values (from every split) are embedded.
        replacers: Replacer passed through to ``_preproc_example``.
        pretrained_checkpoint: Checkpoint used to build the "CM" query
            encoder. Not used by the "BERT" encoder.
        replace_samples: Number of replacement samples passed through to
            ``_preproc_example``.
        encoder_name: Which query encoder to use, "CM" or "BERT".

    Returns:
        A FullRetModel holding the embedder, the stacked example summaries,
        the data split of each example and a reference per example.
        ``choice_models`` is left as None.

    Raises:
        ValueError: If ``encoder_name`` is not supported.
    """
    output_size = 200

    if encoder_name == "CM":
        (x_tokenizer, query_vocab), _y_tokenizer = get_default_tokenizers(
            use_word_piece=True)
        embedder = make_default_query_encoder(x_tokenizer, query_vocab,
                                              output_size,
                                              pretrained_checkpoint)
    elif encoder_name == "BERT":
        # Imported inside the branch, presumably so the BERT dependencies are
        # only needed when this encoder is selected.
        from ainix_kernel.models.EncoderDecoder.bertencoders import BertEncoder
        embedder = BertEncoder()
    else:
        raise ValueError(f"Unsupported encoder_name {encoder_name}")
    # Switch either encoder to inference mode. torch.no_grad() below does not
    # disable train-mode behaviour such as dropout. Assumes BertEncoder is a
    # torch Module, as get_number_of_model_parameters below implies.
    embedder.eval()
    print(
        f"Number of embedder params: {get_number_of_model_parameters(embedder)}"
    )
    parser = StringParser(example_store.type_context)
    unparser = AstUnparser(example_store.type_context,
                           embedder.get_tokenizer())
    summaries, example_refs, example_splits = [], [], []
    # finalize_nb_fn is currently unused: choice_models is left as None below.
    nb_update_fn, finalize_nb_fn = get_nb_learner()
    with torch.no_grad():
        for xval in tqdm(list(example_store.get_all_x_values())):
            # The y set's first value is taken as the most preferred y
            # (assumes y values are ordered by preference).
            all_y_examples = example_store.get_y_values_for_y_set(
                xval.y_set_id)
            most_preferable_y = all_y_examples[0]
            new_summary, new_example_ref = _preproc_example(
                xval,
                most_preferable_y,
                replacers,
                embedder,
                parser,
                unparser,
                # Only TRAIN examples are fed to the nb learner.
                nb_update_fn if xval.split == DataSplits.TRAIN else None,
                replace_samples=replace_samples)
            summaries.append(new_summary)
            example_refs.append(new_example_ref)
            example_splits.append(xval.split)
    return FullRetModel(
        embedder=embedder,
        summaries=torch.stack(summaries),
        dataset_splits=torch.tensor(example_splits),
        example_refs=np.array(example_refs),
        choice_models=None  # finalize_nb_fn() is intentionally not used yet
    )
Example #3
0
def _load_single_example(example_dict: Dict, xtype: str, ytype: str,
                         load_index: ExamplesStore, splits: DataSplitter):
    """Add one loaded example (an x/y pair of values) to ``load_index`` as a y set.

    ``example_dict['x']`` and ``example_dict['y']`` may each be a single value
    or a list of values; every value is converted with ``str``.

    Args:
        example_dict: Mapping with an ``'x'`` and a ``'y'`` entry.
        xtype: Type name passed to ``add_yset`` for the x values.
        ytype: Type name passed to ``add_yset`` for the y values.
        load_index: The store the y set is added to.
        splits: Splitter passed through to ``add_yset``.
    """
    def as_str_list(value):
        # Wrap a lone value so both forms are handled the same way.
        items = value if isinstance(value, list) else [value]
        return [str(item) for item in items]

    x_values = as_str_list(example_dict['x'])
    y_values = as_str_list(example_dict['y'])
    preferences = default_preferences(len(y_values))
    load_index.add_yset(x_values, y_values, xtype, ytype, preferences, splits)


def update_latent_store_from_examples(
        model: 'StringTypeTranslateCF', latent_store: LatentStore,
        examples: ExamplesStore, replacer: Replacer, parser: StringParser,
        splits: Optional[Tuple[DataSplits]], unparser: AstUnparser,
        tokenizer: StringTokenizer):
    """Recompute the model's latent select states and write them to a store.

    For every example in ``splits`` this function:

    1. replaces the x/y text, using a seed derived from the x query;
    2. parses the replaced y text and builds its copy version from the
       tokenized replaced x;
    3. asks the model for the latent select states of that AST;
    4. stores the latents in ``latent_store``.

    The model is put in eval mode while this runs. It is always left in
    train mode afterwards, even if an example raises.

    Args:
        model: Model providing ``get_latent_select_states``.
        latent_store: Store that receives the detached latents.
        examples: Example store to read x values from.
        replacer: Replacer applied to each example's x query and y text.
        parser: Parser used to build the y AST.
        splits: Splits passed to ``get_all_x_values``.
        unparser: Unparser used to build the copy version of the AST.
        tokenizer: Tokenizer for the replaced x; only its metadata is used.
    """
    model.set_in_eval_mode()
    try:
        for example in examples.get_all_x_values(splits):
            # TODO multi sampling and average replacers
            x_replaced, y_replaced = replacer.strings_replace(
                example.xquery, example.ytext,
                seed_from_x_val(example.xquery))
            ast = parser.create_parse_tree(y_replaced, example.ytype)
            _, token_metadata = tokenizer.tokenize(x_replaced)
            copy_ast = copy_tools.make_copy_version_of_tree(
                ast, unparser, token_metadata)
            # TODO Think about whether feeding in the raw x is good idea.
            # will change once have replacer sampling
            latents = model.get_latent_select_states(example.xquery, copy_ast)
            nodes = list(copy_ast.depth_first_iter())
            for i, latent in enumerate(latents):
                # The i-th latent is expected to belong to the node at
                # depth-first index 2*i, which must be an ObjectChoiceNode.
                dfs_depth = i * 2
                n = nodes[dfs_depth].cur_node
                assert isinstance(n, ObjectChoiceNode)
                c = latent.detach()
                assert not c.requires_grad
                latent_store.set_latent_for_example(c, n.type_to_choose.ind,
                                                    example.id, dfs_depth)
    finally:
        # Restore train mode even on failure so the model is not left stuck
        # in eval mode.
        model.set_in_train_mode()
Example #5
0
def make_x_vocab_from_examples(
    example_store: ExamplesStore,
    x_tokenizer: Tokenizer,
    replacers: Replacer,
    x_vocab_builder: VocabBuilder = None,
    min_freq: int = 1,
    train_only: bool = True,
    replace_samples: int = 1,
    extract_lambda=get_text_from_tok
) -> Vocab:
    """Build an x vocab from the replaced x text of the stored examples.

    Args:
        example_store: The example store to generate x values from.
        x_tokenizer: The tokenizer that produces the tokens for the x vocab.
        replacers: Replacer applied to each example's x text before it is
            tokenized. It is called without an explicit seed, once per
            example per pass.
        x_vocab_builder: Builder that accumulates the token sequences. If
            None, ``CounterVocabBuilder(min_freq=min_freq)`` is used.
        min_freq: Passed to the default builder. Ignored when
            ``x_vocab_builder`` is given.
        train_only: If True, only examples in the TRAIN split are used.
        replace_samples: Number of passes over the examples. Each pass runs
            the replacers again.
        extract_lambda: Maps each token from ``x_tokenizer`` to the value
            added to the vocab.

    Returns:
        The x vocab.
    """
    if x_vocab_builder is None:
        x_vocab_builder = CounterVocabBuilder(min_freq=min_freq)

    # The examples to use are the same on every pass, so scan and filter the
    # store once instead of once per sample.
    examples = [
        example for example in example_store.get_all_x_values()
        if not train_only or example.split == DataSplits.TRAIN
    ]
    for _ in range(replace_samples):
        for example in examples:
            xstr, _ = replacers.strings_replace(example.x_text, "")
            x_vocab_builder.add_sequence(
                map(extract_lambda, x_tokenizer.tokenize(xstr)[0])
            )
    return x_vocab_builder.produce_vocab()
Example #6
0
def make_vocab_from_example_store(
    exampe_store: ExamplesStore,
    x_tokenizer: Tokenizer,
    y_tokenizer: Tokenizer,
    x_vocab_builder: VocabBuilder = None,
    y_vocab_builder: VocabBuilder = None
) -> typing.Tuple[Vocab, Vocab]:
    """Creates an x and y vocab based off all examples in an example store.

    Every x query is tokenized and added to the x vocab. Each distinct
    (ytype, ytext) pair is parsed, tokenized and added to the y vocab once,
    even when several x queries share the same y.

    Args:
        exampe_store: The example store to build from. (The misspelled name
            is kept so keyword callers keep working.)
        x_tokenizer: The tokenizer to use for x queries
        y_tokenizer: The tokenizer to use for y queries
        x_vocab_builder: The builder for the kind of x vocab we want to produce.
            If None, a ``CounterVocabBuilder(min_freq=1)`` is used.
        y_vocab_builder: The builder for the kind of y vocab we want to produce.
            If None, a default ``CounterVocabBuilder()`` is used.
    Returns:
        Tuple of the new (x vocab, y vocab).
    """
    if x_vocab_builder is None:
        x_vocab_builder = CounterVocabBuilder(min_freq=1)
    if y_vocab_builder is None:
        y_vocab_builder = CounterVocabBuilder()

    already_done_ys = set()
    parser = StringParser(exampe_store.type_context)
    for example in exampe_store.get_all_x_values():
        x_vocab_builder.add_sequence(x_tokenizer.tokenize(example.xquery)[0])
        # The parse depends on the y type as well as the text, so key on both.
        y_key = (example.ytype, example.ytext)
        if y_key in already_done_ys:
            continue
        # Record the y so later x values sharing it do not re-add its tokens.
        already_done_ys.add(y_key)
        ast = parser.create_parse_tree(example.ytext, example.ytype)
        y_tokens, _ = y_tokenizer.tokenize(ast)
        y_vocab_builder.add_sequence(y_tokens)
    return x_vocab_builder.produce_vocab(), y_vocab_builder.produce_vocab()
Example #7
0
def iterate_data_pairs(
    example_store: ExamplesStore, replacers: Replacer,
    string_parser: StringParser, str_tokenizer: StringTokenizer,
    unparser: AstUnparser, filter_splits: Optional[Tuple[DataSplits]]
) -> Generator[Tuple[XValue, str, AstObjectChoiceSet, ObjectChoiceNode,
                     Set[str], ReplacementSampling], None, None]:
    """Yield one epoch of examples in a random order.

    The examples are shuffled with the module-level ``random`` state. For
    each example, this function:

    1. creates a replacement sampling from the x text and applies it to x;
    2. tokenizes the replaced x;
    3. turns the example's whole y set into an AST set.

    Args:
        example_store: Store providing x values, y sets and the type context.
        replacers: Used to create a replacement sampling per example.
        string_parser: Parser passed to ``make_y_ast_set``.
        str_tokenizer: Tokenizer for the replaced x text.
        unparser: Unparser passed to ``make_y_ast_set``.
        filter_splits: Splits passed to ``get_all_x_values``.

    Yields:
        Tuples of ``(example, replaced_x, y_ast_set, teacher_force_path_ast,
        y_texts, replacement_sample)``. ``y_ast_set`` represents all valid y
        values of the example's y set.
    """
    type_context = example_store.type_context
    all_ex_list = list(example_store.get_all_x_values(filter_splits))
    random.shuffle(all_ex_list)
    for example in all_ex_list:
        all_y_examples = example_store.get_y_values_for_y_set(example.y_set_id)
        # The y type is taken from the first y value of the set.
        y_type = type_context.get_type_by_name(all_y_examples[0].y_type)
        replacement_sample = replacers.create_replace_sampling(example.x_text)
        this_example_replaced_x = replacement_sample.replace_x(example.x_text)
        # Only the token metadata is passed on to make_y_ast_set.
        _, this_x_metadata = str_tokenizer.tokenize(this_example_replaced_x)
        y_ast_set, y_texts, teacher_force_path_ast = make_y_ast_set(
            y_type, all_y_examples, replacement_sample, string_parser,
            this_x_metadata, unparser)
        yield (example, this_example_replaced_x, y_ast_set,
               teacher_force_path_ast, y_texts, replacement_sample)
Example #8
0
def post_process_explanations(
    retr_explans: Tuple[ExampleRetrieveExplanation, ...],
    example_store: ExamplesStore,
    outputted_ast: ObjectChoiceNode,
    outputted_unparser: UnparseResult
) -> List[ExampleExplanPostProcessedOutput]:
    """Turn retrieval explanations into per-example explanation outputs.

    For each (example id, dfs ids) pair chosen by ``_narrow_down_examples``,
    this function:

    * looks the example up in ``example_store``;
    * finds the unparse intervals of the outputted AST covered by those
      dfs ids;
    * drops the example if it has no intervals.

    Args:
        retr_explans: Retrieval explanations to post-process.
        example_store: Store used to look examples up by id.
        outputted_ast: The AST that was output.
        outputted_unparser: Unparse result of ``outputted_ast``.

    Returns:
        One output per kept example, sorted by ``input_str_intervals[0][0]``
        (presumably the start of each output's first interval).
    """
    processed = []
    for ex_id, dfs_ids in _narrow_down_examples(retr_explans):
        example = example_store.get_example_by_id(ex_id)
        spans = _get_unparse_intervals_of_inds(
            dfs_ids, outputted_ast, outputted_unparser)
        if len(spans) == 0:
            continue
        processed.append(
            ExampleExplanPostProcessedOutput(
                example_str=example.xquery,
                example_cmd=example.ytext,
                input_str_intervals=_interval_tree_to_tuples(spans),
            )
        )
    return sorted(processed, key=lambda item: item.input_str_intervals[0][0])
Example #9
0
def get_y_ast_sets(
        xs: List[str], yids: List[int], yhashes: List[str],
        rsamples: List[ReplacementSampling], string_tokenizer: StringTokenizer,
        index: ExamplesStore
) -> Generator[Tuple[AstSet, Set[str]], None, None]:
    """Yield the y AST set and y texts for each x string.

    This is a generator, so the checks below run on the first ``next()``.

    Args:
        xs: The (already replaced) x strings.
        yids: Id of the y set for each x.
        yhashes: Expected hash of each y set. They are checked against
            ``index`` so the y sets have not changed.
        rsamples: Replacement sampling used for each x.
        string_tokenizer: Tokenizer for ``xs``; also used by the unparser.
        index: The example store holding the y sets.

    Yields:
        ``(y_ast_set, y_texts)`` for each x, in input order.
    """
    # Check all lengths first. The hash check pairs yids with yhashes
    # positionally, so a short yhashes would otherwise go unnoticed.
    assert len(xs) == len(yids) == len(yhashes) == len(rsamples), (
        "xs, yids, yhashes and rsamples must all have the same length")
    type_context = index.type_context
    string_parser = StringParser(type_context)
    unparser = AstUnparser(type_context, string_tokenizer)
    verify_same_hashes(index, yids, yhashes)
    for x, yid, rsample in zip(xs, yids, rsamples):
        yvalues = index.get_y_values_for_y_set(yid)
        _, this_x_metadata = string_tokenizer.tokenize(x)
        y_ast_set, y_texts, _ = make_y_ast_set(
            y_type=type_context.get_type_by_name(yvalues[0].y_type),
            all_y_examples=yvalues,
            replacement_sample=rsample,
            string_parser=string_parser,
            this_x_metadata=this_x_metadata,
            unparser=unparser)
        yield y_ast_set, y_texts
Example #10
0
def make_latent_store_from_examples(
    examples: ExamplesStore,
    latent_size: int,
    replacer: Replacer,
    parser: StringParser,
    unparser: AstUnparser,
    splits=(DataSplits.TRAIN, )) -> LatentStore:
    """Build a latent store with an entry for every example in ``splits``.

    The builder is sized with the type count of the store's type context and
    ``latent_size``. For each example:

    * the y text is optionally replaced;
    * the y text is parsed;
    * the copy version of the resulting AST is built from the tokenized x;
    * that AST is registered with the builder under the example's id.

    Args:
        examples: Store providing the examples and the type context.
        latent_size: Size of each latent, passed to the builder.
        replacer: If not None, applied to each (x, y) pair with a seed
            derived from the x query.
        parser: Parser used to build the y AST.
        unparser: Unparser whose input tokenizer tokenizes x and which is used
            to build the copy version of the AST.
        splits: Splits passed to ``get_all_x_values``.

    Returns:
        The builder's produced LatentStore.
    """
    type_count = examples.type_context.get_type_count()
    store_builder = TorchLatentStoreBuilder(type_count, latent_size)
    for example in examples.get_all_x_values(splits):
        x_text, y_text = example.xquery, example.ytext
        if replacer is not None:
            # Seed comes from the x text, presumably so replacement is
            # reproducible for a given example.
            x_text, y_text = replacer.strings_replace(
                x_text, y_text, seed_from_x_val(x_text))
        y_ast = parser.create_parse_tree(y_text, example.ytype)
        _tokens, token_metadata = unparser.input_str_tokenizer.tokenize(x_text)
        copy_ast = copy_tools.make_copy_version_of_tree(
            y_ast, unparser, token_metadata)
        store_builder.add_example(example.id, copy_ast)
    return store_builder.produce_result()
Example #11
0
def make_vocab_from_example_store_and_type_context(
    example_store: ExamplesStore,
    x_tokenizer: Tokenizer,
    x_vocab_builder: VocabBuilder = None
) -> typing.Tuple[Vocab, TypeContextWrapperVocab]:
    """Build an x vocab from the examples and a y vocab from the type context.

    No y values are tokenized. The y vocab simply wraps
    ``example_store.type_context``.

    Args:
        example_store: The example store whose x queries feed the x vocab and
            whose type context backs the y vocab.
        x_tokenizer: Tokenizer whose first output (the tokens) of each x query
            is added to the x vocab.
        x_vocab_builder: Builder for the x vocab. If None, a
            ``CounterVocabBuilder(min_freq=1)`` is used.

    Returns:
        The (x vocab, y vocab) pair.
    """
    builder = (x_vocab_builder if x_vocab_builder is not None
               else CounterVocabBuilder(min_freq=1))

    for example in example_store.get_all_x_values():
        tokens = x_tokenizer.tokenize(example.xquery)[0]
        builder.add_sequence(tokens)
    y_vocab = TypeContextWrapperVocab(example_store.type_context)
    x_vocab = builder.produce_vocab()
    return x_vocab, y_vocab
Example #12
0
 def __init__(self, store_to_cache: ExamplesStore):
     """Set up an empty per-x-value cache over ``store_to_cache``.

     Args:
         store_to_cache: The example store being cached. One empty list is
             allocated for each x value it reports via ``get_num_x_values()``.
     """
     self.store_to_cache = store_to_cache
     num_x_values = store_to_cache.get_num_x_values()
     # A comprehension (rather than [[]] * n) so every slot is its own list.
     self.x_vals_cache = [[] for _slot in range(num_x_values)]
Example #13
0
def _example_inds_to_examples(example_inds: Sequence[int], example_store: ExamplesStore):
    """Look up every example id in ``example_inds`` in ``example_store``.

    Returns:
        A list of the looked-up examples, in the same order as the ids.
    """
    examples = []
    for example_id in example_inds:
        examples.append(example_store.get_example_by_id(example_id))
    return examples