def verify_same_hashes(index: ExamplesStore, yids, yhashes):
    for yid, yhash in zip(yids, yhashes):
        if index.get_y_set_hash(yid) != yhash:
            for yv in index.get_y_values_for_y_set(yid):
                print(yv)
            print(index.get_y_set_hash(yid))
            raise ValueError(f"Hashes do not match for y set {yid}")

def full_ret_from_example_store(
        example_store: ExamplesStore,
        replacers: Replacer,
        pretrained_checkpoint: str,
        replace_samples: int = REPLACEMENT_SAMPLES,
        encoder_name: str = "CM"
) -> FullRetModel:
    output_size = 200
    # For GloVe the parseable vocab is huge; narrow it down.
    #query_vocab = make_x_vocab_from_examples(example_store, x_tokenizer, replacers,
    #                                         min_freq=2, replace_samples=2)
    #print(f"vocab len {len(query_vocab)}")
    if encoder_name == "CM":
        (x_tokenizer, query_vocab), y_tokenizer = get_default_tokenizers(
            use_word_piece=True)
        embedder = make_default_query_encoder(x_tokenizer, query_vocab,
                                              output_size, pretrained_checkpoint)
        embedder.eval()
    elif encoder_name == "BERT":
        from ainix_kernel.models.EncoderDecoder.bertencoders import BertEncoder
        embedder = BertEncoder()
    else:
        raise ValueError(f"Unsupported encoder_name {encoder_name}")
    print(f"Number of embedder params: {get_number_of_model_parameters(embedder)}")
    parser = StringParser(example_store.type_context)
    unparser = AstUnparser(example_store.type_context, embedder.get_tokenizer())
    summaries, example_refs, example_splits = [], [], []
    nb_update_fn, finalize_nb_fn = get_nb_learner()
    with torch.no_grad():
        for xval in tqdm(list(example_store.get_all_x_values())):
            # Take the most preferred y text for this x value's y set.
            all_y_examples = example_store.get_y_values_for_y_set(xval.y_set_id)
            most_preferable_y = all_y_examples[0]
            new_summary, new_example_ref = _preproc_example(
                xval, most_preferable_y, replacers, embedder, parser, unparser,
                nb_update_fn if xval.split == DataSplits.TRAIN else None,
                replace_samples=replace_samples)
            summaries.append(new_summary)
            example_refs.append(new_example_ref)
            example_splits.append(xval.split)
    return FullRetModel(
        embedder=embedder,
        summaries=torch.stack(summaries),
        dataset_splits=torch.tensor(example_splits),
        example_refs=np.array(example_refs),
        choice_models=None  # finalize_nb_fn()
    )

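
# Usage sketch (illustrative, not part of the original module): building a
# retrieval model from an existing example store. `example_store`, `replacers`,
# and `checkpoint_path` are assumed to be constructed elsewhere; "CM" selects
# the default query encoder, "BERT" the BertEncoder.
def _build_ret_model_sketch(example_store: ExamplesStore,
                            replacers: Replacer,
                            checkpoint_path: str) -> FullRetModel:
    return full_ret_from_example_store(
        example_store, replacers, checkpoint_path,
        replace_samples=REPLACEMENT_SAMPLES,
        encoder_name="CM")
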
def _load_single_example(example_dict: Dict, xtype: str, ytype: str,
                         load_index: ExamplesStore, splits: DataSplitter):
    """Normalizes the x and y entries of one example dict into lists of strings
    and adds them to the load index as a single y set."""
    x = example_dict['x']
    if not isinstance(x, list):
        x = [x]
    y = example_dict['y']
    if not isinstance(y, list):
        y = [y]
    x = list(map(str, x))
    y = list(map(str, y))
    load_index.add_yset(x, y, xtype, ytype, default_preferences(len(y)), splits)

def update_latent_store_from_examples(
        model: 'StringTypeTranslateCF',
        latent_store: LatentStore,
        examples: ExamplesStore,
        replacer: Replacer,
        parser: StringParser,
        splits: Optional[Tuple[DataSplits]],
        unparser: AstUnparser,
        tokenizer: StringTokenizer):
    model.set_in_eval_mode()
    for example in examples.get_all_x_values(splits):
        # TODO: multi-sample the replacers and average the resulting latents.
        x_replaced, y_replaced = replacer.strings_replace(
            example.xquery, example.ytext, seed_from_x_val(example.xquery))
        ast = parser.create_parse_tree(y_replaced, example.ytype)
        _, token_metadata = tokenizer.tokenize(x_replaced)
        copy_ast = copy_tools.make_copy_version_of_tree(
            ast, unparser, token_metadata)
        # TODO: Think about whether feeding in the raw x is a good idea.
        # This will change once replacer sampling is in place.
        latents = model.get_latent_select_states(example.xquery, copy_ast)
        nodes = list(copy_ast.depth_first_iter())
        #print("LATENTS", latents)
        for i, l in enumerate(latents):
            dfs_depth = i * 2
            n = nodes[dfs_depth].cur_node
            assert isinstance(n, ObjectChoiceNode)
            c = l.detach()
            assert not c.requires_grad
            latent_store.set_latent_for_example(
                c, n.type_to_choose.ind, example.id, dfs_depth)
    model.set_in_train_mode()

def make_x_vocab_from_examples(
        example_store: ExamplesStore,
        x_tokenizer: Tokenizer,
        replacers: Replacer,
        x_vocab_builder: VocabBuilder = None,
        min_freq: int = 1,
        train_only: bool = True,
        replace_samples: int = 1,
        extract_lambda=get_text_from_tok
) -> Vocab:
    """Creates an x vocab from the x values in an example store.

    Args:
        example_store: The example store to generate x values from.
        x_tokenizer: The tokenizer to generate the tokens we will put in the x vocab.
        replacers: The replacer used to fill replacement placeholders in the x text.
        x_vocab_builder: The builder used to accumulate tokens into the x vocab.
            If None, a CounterVocabBuilder with min_freq is used.
        min_freq: Minimum token frequency for the default vocab builder.
        train_only: If True, only x values in the TRAIN split are counted.
        replace_samples: Number of replacement samplings to take per example.
        extract_lambda: Function mapping each token to the value added to the vocab.

    Returns:
        The x vocab.
    """
    if x_vocab_builder is None:
        x_vocab_builder = CounterVocabBuilder(min_freq=min_freq)
    for _ in range(replace_samples):
        for example in example_store.get_all_x_values():
            if example.split != DataSplits.TRAIN and train_only:
                continue
            xstr, _ = replacers.strings_replace(example.x_text, "")
            x_vocab_builder.add_sequence(
                map(extract_lambda, x_tokenizer.tokenize(xstr)[0]))
    return x_vocab_builder.produce_vocab()

def make_vocab_from_example_store(
        example_store: ExamplesStore,
        x_tokenizer: Tokenizer,
        y_tokenizer: Tokenizer,
        x_vocab_builder: VocabBuilder = None,
        y_vocab_builder: VocabBuilder = None
) -> typing.Tuple[Vocab, Vocab]:
    """Creates an x and y vocab based off all examples in an example store.

    Args:
        example_store: The example store to build from.
        x_tokenizer: The tokenizer to use for x queries.
        y_tokenizer: The tokenizer to use for y queries.
        x_vocab_builder: The builder for the kind of x vocab we want to produce.
            If None, it just picks a reasonable default.
        y_vocab_builder: The builder for the kind of y vocab we want to produce.
            If None, it just picks a reasonable default.

    Returns:
        Tuple of the new (x vocab, y vocab).
    """
    if x_vocab_builder is None:
        x_vocab_builder = CounterVocabBuilder(min_freq=1)
    if y_vocab_builder is None:
        y_vocab_builder = CounterVocabBuilder()
    already_done_ys = set()
    parser = StringParser(example_store.type_context)
    for example in example_store.get_all_x_values():
        x_vocab_builder.add_sequence(x_tokenizer.tokenize(example.xquery)[0])
        if example.ytext not in already_done_ys:
            # Only parse and tokenize each distinct y text once.
            already_done_ys.add(example.ytext)
            ast = parser.create_parse_tree(example.ytext, example.ytype)
            y_tokens, _ = y_tokenizer.tokenize(ast)
            y_vocab_builder.add_sequence(y_tokens)
    return x_vocab_builder.produce_vocab(), y_vocab_builder.produce_vocab()

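
# Usage sketch (illustrative, not part of the original module): building both
# vocabs with the default tokenizers. The unpacking of get_default_tokenizers
# follows the pattern used in full_ret_from_example_store above; whether that
# y_tokenizer is the right one for AST tokenization depends on your setup.
def _build_vocabs_sketch(example_store: ExamplesStore) -> typing.Tuple[Vocab, Vocab]:
    (x_tokenizer, _), y_tokenizer = get_default_tokenizers(use_word_piece=True)
    return make_vocab_from_example_store(example_store, x_tokenizer, y_tokenizer)
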
def iterate_data_pairs(
        example_store: ExamplesStore,
        replacers: Replacer,
        string_parser: StringParser,
        str_tokenizer: StringTokenizer,
        unparser: AstUnparser,
        filter_splits: Optional[Tuple[DataSplits]]
) -> Generator[Tuple[XValue, str, AstObjectChoiceSet, ObjectChoiceNode,
                     Set[str], ReplacementSampling], None, None]:
    """Yields one epoch of examples in shuffled order. Each item is a tuple of
    the example, its replacement-sampled x text, the AST set representing all
    valid y values for that example, the teacher-forcing path AST, the y texts,
    and the replacement sampling that was used."""
    type_context = example_store.type_context
    all_ex_list = list(example_store.get_all_x_values(filter_splits))
    random.shuffle(all_ex_list)
    for example in all_ex_list:
        all_y_examples = example_store.get_y_values_for_y_set(example.y_set_id)
        y_type = type_context.get_type_by_name(all_y_examples[0].y_type)
        replacement_sample = replacers.create_replace_sampling(example.x_text)
        this_example_replaced_x = replacement_sample.replace_x(example.x_text)
        this_x_tokens, this_x_metadata = str_tokenizer.tokenize(
            this_example_replaced_x)
        y_ast_set, y_texts, teacher_force_path_ast = make_y_ast_set(
            y_type, all_y_examples, replacement_sample, string_parser,
            this_x_metadata, unparser)
        yield (example, this_example_replaced_x, y_ast_set,
               teacher_force_path_ast, y_texts, replacement_sample)

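
# Usage sketch (illustrative, not part of the original module): consuming one
# epoch of training pairs. The parser/unparser construction mirrors
# get_y_ast_sets below, and the print is only a stand-in for whatever a real
# training loop would do with the yielded values.
def _iterate_train_pairs_sketch(example_store: ExamplesStore,
                                replacers: Replacer,
                                string_tokenizer: StringTokenizer):
    type_context = example_store.type_context
    string_parser = StringParser(type_context)
    unparser = AstUnparser(type_context, string_tokenizer)
    for (example, replaced_x, y_ast_set, teacher_force_ast,
         y_texts, rsample) in iterate_data_pairs(
            example_store, replacers, string_parser, string_tokenizer,
            unparser, (DataSplits.TRAIN,)):
        print(example.id, replaced_x, len(y_texts))
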
def post_process_explanations(
        retr_explans: Tuple[ExampleRetrieveExplanation, ...],
        example_store: ExamplesStore,
        outputted_ast: ObjectChoiceNode,
        outputted_unparser: UnparseResult
) -> List[ExampleExplanPostProcessedOutput]:
    narrowed_down_examples = _narrow_down_examples(retr_explans)
    out = []
    for example_id, use_dfs_ids in narrowed_down_examples:
        actual_example = example_store.get_example_by_id(example_id)
        intervals = _get_unparse_intervals_of_inds(
            use_dfs_ids, outputted_ast, outputted_unparser)
        if len(intervals) == 0:
            # Skip explanations that don't map to any interval.
            continue
        out.append(ExampleExplanPostProcessedOutput(
            example_str=actual_example.xquery,
            example_cmd=actual_example.ytext,
            input_str_intervals=_interval_tree_to_tuples(intervals)
        ))
    # Present explanations in the order of their first interval's start.
    out.sort(key=lambda v: v.input_str_intervals[0][0])
    return out

def get_y_ast_sets(
        xs: List[str],
        yids: List[int],
        yhashes: List[str],
        rsamples: List[ReplacementSampling],
        string_tokenizer: StringTokenizer,
        index: ExamplesStore
) -> Generator[Tuple[AstSet, Set[str]], None, None]:
    """Yields the y AST set and y texts for each (x, y set id, replacement
    sampling) triple, after verifying the stored y set hashes match yhashes."""
    type_context = index.type_context
    string_parser = StringParser(type_context)
    unparser = AstUnparser(type_context, string_tokenizer)
    verify_same_hashes(index, yids, yhashes)
    assert len(xs) == len(yids) == len(rsamples)
    for x, yid, rsample in zip(xs, yids, rsamples):
        yvalues = index.get_y_values_for_y_set(yid)
        this_x_toks, this_x_metadata = string_tokenizer.tokenize(x)
        y_ast_set, y_texts, teacher_force_path_ast = make_y_ast_set(
            y_type=type_context.get_type_by_name(yvalues[0].y_type),
            all_y_examples=yvalues,
            replacement_sample=rsample,
            string_parser=string_parser,
            this_x_metadata=this_x_metadata,
            unparser=unparser)
        yield y_ast_set, y_texts

def make_latent_store_from_examples(
        examples: ExamplesStore,
        latent_size: int,
        replacer: Replacer,
        parser: StringParser,
        unparser: AstUnparser,
        splits=(DataSplits.TRAIN,)
) -> LatentStore:
    builder = TorchLatentStoreBuilder(
        examples.type_context.get_type_count(), latent_size)
    for example in examples.get_all_x_values(splits):
        if replacer is not None:
            x = example.xquery
            x, y = replacer.strings_replace(x, example.ytext, seed_from_x_val(x))
        else:
            x, y = example.xquery, example.ytext
        ast = parser.create_parse_tree(y, example.ytype)
        _, token_metadata = unparser.input_str_tokenizer.tokenize(x)
        ast_with_copies = copy_tools.make_copy_version_of_tree(
            ast, unparser, token_metadata)
        builder.add_example(example.id, ast_with_copies)
    return builder.produce_result()

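
# Usage sketch (illustrative, not part of the original module): building a
# latent store over the TRAIN split and then filling its latents from a trained
# model via update_latent_store_from_examples above. `model`, `latent_size`,
# and `tokenizer` are assumed to be provided by the caller.
def _build_and_fill_latent_store_sketch(
        model: 'StringTypeTranslateCF',
        examples: ExamplesStore,
        latent_size: int,
        replacer: Replacer,
        parser: StringParser,
        unparser: AstUnparser,
        tokenizer: StringTokenizer) -> LatentStore:
    latent_store = make_latent_store_from_examples(
        examples, latent_size, replacer, parser, unparser,
        splits=(DataSplits.TRAIN,))
    update_latent_store_from_examples(
        model, latent_store, examples, replacer, parser,
        (DataSplits.TRAIN,), unparser, tokenizer)
    return latent_store
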
def make_vocab_from_example_store_and_type_context(
        example_store: ExamplesStore,
        x_tokenizer: Tokenizer,
        x_vocab_builder: VocabBuilder = None
) -> typing.Tuple[Vocab, TypeContextWrapperVocab]:
    """Like make_vocab_from_example_store, except this doesn't do a special
    tokenization to get a y vocab and instead just wraps everything from the
    example_store's type context.

    Args:
        example_store: The example store to generate x values from.
        x_tokenizer: The tokenizer to generate the tokens we will put in the x vocab.
        x_vocab_builder: The builder used to accumulate tokens into the x vocab.
            If None, it just picks a reasonable default.

    Returns:
        The x and y vocab.
    """
    if x_vocab_builder is None:
        x_vocab_builder = CounterVocabBuilder(min_freq=1)
    for example in example_store.get_all_x_values():
        x_vocab_builder.add_sequence(x_tokenizer.tokenize(example.xquery)[0])
    y_vocab = TypeContextWrapperVocab(example_store.type_context)
    return x_vocab_builder.produce_vocab(), y_vocab

def __init__(self, store_to_cache: ExamplesStore):
    self.store_to_cache = store_to_cache
    self.x_vals_cache = [[] for _ in range(store_to_cache.get_num_x_values())]

def _example_inds_to_examples(example_inds: Sequence[int],
                              example_store: ExamplesStore):
    return [example_store.get_example_by_id(eid) for eid in example_inds]