def _preload_data(self, file_path, label_file_path, batch_size=1000, max_len=None):
    seqs = seq_all(file_path)
    codes = seqs["codes"]
    label_index_df = pd.read_pickle(label_file_path)
    piece_labels = label_index_df["label"]
    indexes = label_index_df["index"]
    sample_ids = label_index_df["sample_id"]
    code_pieces = fetch_code_pieces(codes, sample_ids, indexes)
    ast_label_seqs = seq_from_code_ast(seqs)
    # sub_codes/labels feed the alternative sampling strategy kept below.
    sub_codes, labels = ast_label_seqs["sub_code_pieces"], ast_label_seqs[
        self.hparams["snake_params"].label_type]
    # TODO: try a different sampling strategy for sub_codes.
    # code_pieces, piece_labels = sq.concatenate(sub_codes), sq.concatenate(labels)
    # code_pieces = sq.smap(utf8decode, code_pieces)
    tok_codes = sq.smap(tokenize_plus(self.tokenizer, max_len, True), code_pieces)
    tok_piece_labels = sq.smap(label_tokenize(self.label_tokenizer), piece_labels)
    return sq.collate([tok_codes, tok_piece_labels])

def _preload_data(self, file_path, batch_size=1000, max_len=None):
    seqs = seq_all(file_path)
    codes, docs = seqs["codes"], seqs["docs"]
    tok_codes = sq.smap(tokenize_plus(self.tokenizer, max_len, True), codes)
    tok_docs = sq.smap(tokenize_plus(self.tokenizer, max_len, True), docs)
    return sq.collate([tok_codes, tok_docs])

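# Hedged usage sketch (added, not part of the original module): `loader` is
# assumed to be an instance of the class that owns _preload_data above, and
# `path` a pickled dataset file. Shows how the collated (code, doc) pairs can
# be minibatched lazily with the seqtools calls already used in this repo.
def demo_preload(loader, path, batch_size=32):
    pairs = loader._preload_data(path, max_len=128)
    batches = sq.batch(pairs, k=batch_size, collate_fn=list)
    return sq.prefetch(batches, nworkers=2, max_buffered=4, method="thread")
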
def test_throughput():
    def proc(x):
        time.sleep(proc.delay)
        return x

    proc.delay = 0.01
    arr = list(range(100))
    monitored_arr = monitor_throughput(smap(proc, arr))
    x = list(monitored_arr)
    assert x == arr
    assert monitored_arr.throughput() - 100 < 1
    monitored_arr.reset()
    with pytest.raises(RuntimeError):
        monitored_arr.read_delay()
    with pytest.raises(RuntimeError):
        monitored_arr.throughput()

    proc.delay = 0.02
    arr = list(range(100))
    monitored_arr = monitor_throughput(smap(proc, arr))
    x = [monitored_arr[i] for i in range(len(monitored_arr))]
    assert x == arr
    assert monitored_arr.read_delay() - 0.02 < 0.002

def seq_from_code_ast(_seq_dict):
    _code_bytes = _seq_dict["code_bytes"]
    # FIXME: for PHP the tree will ALMOST always contain only 'program' and
    # 'text' nodes (even on the playground). Fix it by wrapping the code bytes
    # in <?php ... ?> and check whether any exceptions remain (the label
    # counts show a different view). Java needs an extra `class Test { ... }`
    # wrapper, otherwise it will not parse correctly.
    _asts = _seq_dict["asts"]
    sub_code_pieces = sq.smap(_sub_code_pieces, _asts, _code_bytes)
    sub_code_indexes = sq.smap(_sub_code_indexes, _asts)
    # sub_asts = sq.smap(_sub_labels, _asts)
    sub_labels = sq.smap(_sub_labels, _asts)
    type_label = sq.smap(
        lambda lbs: [[x[1] for x in labels] for labels in lbs],
        sub_labels)
    combined_label = sq.smap(
        lambda lbs: [[f"{x[0]}-{x[1]}" for x in labels] for labels in lbs],
        sub_labels)
    # Return every non-underscore local as a named lazy sequence.
    _dict_all = locals()
    _dict_return = {k: v for k, v in _dict_all.items() if not k.startswith("_")}
    # print(_dict_return.keys())
    return _dict_return

def test_prefetch_timing(method):
    def f1(x):
        sleep(.02 + 0.01 * (random.random() - .5))
        return x

    arr = list(range(100))
    y = smap(f1, arr)
    y = prefetch(y, nworkers=2, max_cached=20, method=method, timeout=1)
    t1 = time()
    z = list(y)
    t2 = time()
    assert z == arr
    duration = t2 - t1
    print("test_prefetch_timing({}):1 {}".format(method, duration))
    assert duration < 1.3

    arr = list(range(200))
    y = smap(f1, arr)
    y = prefetch(y, nworkers=2, max_cached=20, method=method, timeout=1,
                 anticipate=lambda i: i + 2)
    t1 = time()
    z = [y[i] for i in range(0, len(y), 2)]
    t2 = time()
    assert z == arr[::2]
    duration = t2 - t1
    print("test_prefetch_timing({}):2 {}".format(method, duration))
    assert duration < 1.3

def test_prefetch(method):
    def f1(x):
        sleep(0.005 * (1 + random.random()))
        return x

    if method == "process":
        start_hook = random.seed
    else:
        start_hook = None

    arr = list(range(300))
    y = smap(f1, arr)
    y = prefetch(y, nworkers=4, max_buffered=10, method=method, timeout=1,
                 start_hook=start_hook)

    # check if workers are properly restarted when asleep
    i = 0
    n_wakeups = 3
    for _ in range(500):
        if n_wakeups > 0 and random.random() < 0.005:
            sleep(1.1)  # will let worker go to sleep
            n_wakeups -= 1
        value = y[i]
        assert value == arr[i]
        if random.random() < 0.05:
            i = random.randrange(0, len(arr))
        else:
            i = (i + 1) % len(arr)

    # helps with coverage
    y.async_seq._finalize(y.async_seq)

    # overly large buffer
    arr = list(range(10))
    y = smap(f1, arr)
    y = prefetch(y, nworkers=4, max_buffered=50, method=method, timeout=1)
    assert list(y) == arr

    # anticipate method
    arr = list(range(200))
    y = smap(f1, arr)
    y = prefetch(y, nworkers=2, max_buffered=20, method=method, timeout=1,
                 anticipate=lambda i: i + 2)
    z = [y[i] for i in range(0, len(y), 2)]
    assert z == arr[::2]

def _preload_data(self, file_path, batch_size=1000, max_len=None):
    seqs = seq_all(file_path)
    codes = seqs["codes"]
    docs = seqs["docs"]
    # TODO: try a different sampling strategy for sub_codes.
    tok_both = sq.smap(tokenize_pair_plus(self.tokenizer, max_len, True),
                       docs, codes)
    # FIXME: <PAD> tokens are currently included in the random-mask candidates.
    tok_only = sq.smap(lambda x: x["input_ids"], tok_both)
    tok_piece_labels = sq.smap(
        random_mask(list(self.tokenizer.get_vocab().values())), tok_only)
    return sq.collate([tok_both, tok_piece_labels])

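# Hedged sketch (added, not the repo's code): `random_mask` is defined
# elsewhere; a plausible BERT-style masking helper is outlined below for
# context. The 15% rate, the corruption scheme, and the -100 ignore label are
# all assumptions, and this sketch shares the FIXME above: it draws
# replacements from the full vocabulary, <PAD> included.
import random

def random_mask_sketch(vocab_ids, mask_prob=0.15):
    def apply(input_ids):
        masked, labels = list(input_ids), list(input_ids)
        for i in range(len(masked)):
            if random.random() < mask_prob:
                masked[i] = random.choice(vocab_ids)  # corrupt this position
            else:
                labels[i] = -100  # ignored by the loss at unmasked positions
        return masked, labels
    return apply
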
def test_prefetch_errors(method):
    class CustomError(Exception):
        pass

    def f1(x):
        if x is None:
            raise CustomError()
        else:
            return x

    arr1 = [1, 2, 3, None]
    arr2 = smap(f1, arr1)
    y = prefetch(arr2, nworkers=2, max_cached=2, method=method)
    for i in range(3):
        assert y[i] == arr1[i]
    with pytest.raises(PrefetchException):
        a = y[3]
        del a

    def f2(x):
        if x is None:
            raise ValueError("blablabla")
        else:
            return x

    # helps with coverage
    y._finalize(y)

    arr2 = smap(f2, arr1)
    y = prefetch(arr2, nworkers=2, max_cached=2, method=method)
    for i in range(3):
        assert y[i] == arr1[i]
    try:
        a = y[3]
        del a
    except Exception as e:
        assert isinstance(e, PrefetchException)
        assert isinstance(e.__cause__, ValueError)
    else:
        assert False
    assert y[0] == 1
    assert y[1] == 2

    # helps with coverage
    y._finalize(y)

def test_prefetch(method):
    def f1(x):
        sleep(0.005 * (1 + random.random()))
        return x

    start_hook = None  # no per-worker initialization needed for either method

    arr = list(range(300))
    y = smap(f1, arr)
    y = prefetch(y, nworkers=4, max_cached=10, method=method, timeout=1,
                 start_hook=start_hook)
    # arr = arr[3:-1:2]
    # y = y[3:-1:2]

    i = 0
    n_wakeups = 3
    for _ in range(500):
        if n_wakeups > 0 and random.random() < 0.005:
            sleep(1.1)  # will let worker go to sleep
            n_wakeups -= 1
        assert y[i] == arr[i]
        if random.random() < 0.05:
            i = random.randrange(0, len(arr))
        else:
            i = (i + 1) % len(arr)

    # helps with coverage
    y._finalize(y)

def test_debug():
    arr = list(range(100))

    def do(i, v):
        del i, v
        do.i += 1

    do.i = 0
    debugged_arr = debug(arr, do)
    assert list(debugged_arr) == arr
    assert do.i == 100
    assert [debugged_arr[i] for i in range(len(debugged_arr))] == arr
    assert do.i == 200

    do.i = 0
    debugged_arr = debug(arr, do, max_calls=3)
    assert list(debugged_arr) == arr
    assert do.i == 3

    def proc(x):
        time.sleep(0.01)
        return x

    do.i = 0
    debugged_arr = debug(smap(proc, arr), do, max_rate=10)
    assert list(debugged_arr) == arr
    assert do.i == 10

def test_cached():
    def f(x):
        return x

    cache_size = 3
    arr = [random.random() for _ in range(25)]

    z = add_cache(arr, cache_size)
    assert list(z) == arr
    assert list(z[10:]) == arr[10:]
    assert [z[i] for i in range(10)] == arr[:10]
    z[:10] = list(range(0, -10, -1))
    assert list(z[10:]) == arr[10:]
    assert list(z[:10]) == list(range(0, -10, -1))

    y = smap(f, arr)
    z = add_cache(y, cache_size)
    t1 = time()
    for i in range(len(arr)):
        assert z[i] == arr[i]
        for j in range(max(0, i - cache_size + 1), i + 1):
            assert z[j] == arr[j]
    t2 = time()
    duration = t2 - t1
    assert duration < .28

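# Hedged demo (added): a self-contained illustration of the caching behaviour
# exercised above; repeated reads of a recently accessed item hit the cache
# instead of re-running the mapped function.
def demo_add_cache():
    import seqtools

    def square(x):
        square.calls += 1
        return x * x

    square.calls = 0
    cached = seqtools.add_cache(seqtools.smap(square, range(10)), 3)
    assert cached[0] == 0 and cached[0] == 0  # second read is a cache hit
    assert square.calls == 1
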
def test_prefetch_errors(method, evaluation, picklable_err):
    class CustomError(Exception):
        pass

    def f1(x):
        if x is None:
            raise ValueError("blablabla") if picklable_err else CustomError()
        else:
            return x

    arr1 = [1, 2, 3, None]
    arr2 = smap(f1, arr1)
    y = prefetch(arr2, nworkers=2, max_buffered=2, method=method)

    seterr(evaluation)
    if (method == "process" and not picklable_err) or evaluation == "wrap":
        error_t = EvaluationError
    else:
        error_t = ValueError if picklable_err else CustomError

    for i in range(3):
        assert y[i] == arr1[i]
    with pytest.raises(error_t):
        a = y[3]
        del a

def make_sequence(self):
    """Build a sequence that looks like a dataloader when iterated over."""
    # shuffling
    if self.batch_sampler:
        batch_indices = list(self.batch_sampler)
        out = seqtools.smap(lambda bi: [self.dataset[i] for i in bi],
                            batch_indices)
    elif self.sampler:
        shuffle_indices = list(self.sampler)
        out = seqtools.gather(self.dataset, shuffle_indices)
    elif self.shuffle:
        shuffle_indices = np.random.permutation(len(self.dataset))
        out = seqtools.gather(self.dataset, shuffle_indices)
    else:
        out = self.dataset

    # batch
    if not self.batch_sampler and self.batch_size is not None:
        out = seqtools.batch(out, k=self.batch_size, drop_last=self.drop_last,
                             collate_fn=self.collate_fn)
    elif self.batch_sampler:
        out = seqtools.smap(self.collate_fn, out)

    # prefetch
    if self.num_workers > 0:
        out = seqtools.prefetch(
            out,
            max_buffered=self.num_workers * self.prefetch_factor,
            nworkers=self.num_workers,
            method='process',
            start_hook=self.worker_init_fn,
            shm_size=self.shm_size)

    # convert into tensors
    out = seqtools.smap(into_tensors, out)

    # pin memory
    if self.pin_memory:
        out = seqtools.smap(pin_tensors_memory, out)
        out = seqtools.prefetch(out, nworkers=1, method='thread',
                                max_buffered=1)

    return out

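# Hedged usage sketch (added): assuming the method above belongs to a
# DataLoader-like wrapper class (its constructor is not shown here), a typical
# training loop rebuilds the sequence each epoch so that the shuffling indices
# are redrawn. `loader` and `train_step` are placeholder names.
def run_epochs(loader, train_step, n_epochs=2):
    for _ in range(n_epochs):
        for minibatch in loader.make_sequence():  # fresh shuffle every epoch
            train_step(minibatch)
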
def transform_frames(frame_seq, t: Transformation):
    # shorthand notations
    duration = len(frame_seq)
    sx, sy = t.xscale, t.yscale
    rx, ry = t.ref2d
    rz = t.ref3d[2]

    # generate affine transformation matrix
    # triangles_src = np.array([[rx, rx + 1, rx], [ry, ry, ry + 1]], dtype=np.float32)
    triangles_src = np.array([[rx, ry], [rx + 10, ry], [rx, ry + 10]],
                             dtype=np.float32)
    triangles_dst = np.copy(triangles_src)
    x = triangles_dst[:, 0]
    y = triangles_dst[:, 1]
    z_corrections = 1 + t.zshift / (rz + .0001)
    x[...] = (x - rx) / z_corrections + rx
    y[...] = (y - ry) / z_corrections + ry
    x[...] = (x - rx) * sx + rx
    y[...] = (y - ry) * sy + ry
    # rotate both coordinates jointly so that the y update does not reuse the
    # already-rotated x values
    xr = rx + np.cos(t.tilt) * (x - rx) - np.sin(t.tilt) * (y - ry)
    yr = ry + np.sin(t.tilt) * (x - rx) + np.cos(t.tilt) * (y - ry)
    x[...] = xr
    y[...] = yr
    tmatrix = cv2.getAffineTransform(triangles_src, triangles_dst)

    # affine frame-wise transformations
    output = seqtools.smap(lambda f: cv2.warpAffine(f, tmatrix, (640, 480)),
                           frame_seq)

    # fliplr
    if t.fliplr:
        output = seqtools.smap(np.fliplr, output)

    # time scale
    if t.tscale != 1:
        output_duration = transform_durations(duration, t)
        indices = np.round(
            np.linspace(0, duration - 1, output_duration)).astype(int)
        output = seqtools.gather(output, indices)
        if t.tscale > 1:
            output = seqtools.add_cache(output, cache_size=1)

    return output

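# Hedged usage example (added): drives transform_frames with dummy frames and
# an identity-like transformation. `SimpleNamespace` stands in for the repo's
# Transformation record type; the field names follow the code above, the
# values are made up for illustration.
def demo_transform_frames():
    from types import SimpleNamespace
    frames = [np.zeros((480, 640, 3), dtype=np.uint8) for _ in range(5)]
    t = SimpleNamespace(xscale=1., yscale=1., ref2d=(320., 240.),
                        ref3d=(0., 0., 2.), zshift=0., tilt=0.,
                        fliplr=False, tscale=1.)
    out = transform_frames(frames, t)
    assert len(out) == 5 and out[0].shape == (480, 640, 3)
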
def test_smap_exceptions():
    def do(x):
        del x
        raise CustomException

    data = [random.random() for _ in range(100)]
    m = smap(do, data)

    with pytest.raises(CustomException):
        print(m[0])
    with pytest.raises(CustomException):
        next(iter(m))
    with pytest.raises(TypeError):
        smap(None, data)
    with pytest.raises(ValueError):
        smap(do)

def target():
    # f1, method and init_fn come from the enclosing scope
    arr = np.random.rand(1000, 10)
    y = smap(f1, arr)
    y = prefetch(y, method=method, max_buffered=40, nworkers=4,
                 start_hook=init_fn)
    for i in range(0, 1000):
        a = y[i]

def test_prefetch_timings(prefetch_kwargs):
    def f1(x):
        sleep(0.005 * (1 + random.random()))
        return x

    start_hook = random.seed

    arr = np.random.rand(100, 10)
    y = smap(f1, arr)
    y = prefetch(y, nworkers=4, max_buffered=10, start_hook=start_hook,
                 **prefetch_kwargs)
    y = [y_.copy() for y_ in y]  # copy needed to release buffers when shm_size > 0
    assert_array_equal(np.stack(y), arr)

    # overly large buffer
    arr = np.random.rand(10, 10)
    y = smap(f1, arr)
    y = prefetch(y, nworkers=4, max_buffered=50, **prefetch_kwargs)
    y = [y_.copy() for y_ in y]
    assert_array_equal(np.stack(y), arr)

    # multiple restarts
    arr = np.random.rand(100, 10)
    y = smap(f1, arr)
    y = prefetch(y, nworkers=4, max_buffered=10, **prefetch_kwargs)
    for _ in range(10):
        n = np.random.randint(0, 99)
        for i in range(n):
            assert_array_equal(y[i], arr[i])

    # starvation
    arr = np.random.rand(100, 10)
    y = prefetch(arr, nworkers=2, max_buffered=10, **prefetch_kwargs)
    y[0]
    sleep(2)
    for i in range(1, 100):
        assert_array_equal(y[i], arr[i])

def test_smap_exceptions(evaluation):
    def do(x):
        del x
        raise CustomException

    data = [random.random() for _ in range(100)]
    m = smap(do, data)

    seterr(evaluation)
    error_t = EvaluationError if evaluation == "wrap" else CustomException

    with pytest.raises(error_t):
        print(m[0])
    with pytest.raises(error_t):
        next(iter(m))
    with pytest.raises(TypeError):
        smap(None, data)
    with pytest.raises(ValueError):
        smap(do)

def reload():
    global train_subset, val_subset, test_subset, \
        durations, gloss_seqs, pose2d_seqs, pose3d_seqs, frame_seqs

    with open(os.path.join(cachedir, "data.pkl"), 'rb') as f:
        durations, gloss_seqs, rec_mapping, transformations, \
            train_subset, val_subset, test_subset = pkl.load(f)

    segments = np.stack(
        [np.cumsum(durations) - durations, np.cumsum(durations)], axis=1)
    pose2d_seqs = seqtools.split(
        np.load(os.path.join(cachedir, "pose2d_seqs.npy"), mmap_mode='r'),
        segments)
    pose3d_seqs = seqtools.split(
        np.load(os.path.join(cachedir, "pose3d_seqs.npy"), mmap_mode='r'),
        segments)
    frame_seqs = seqtools.smap(lambda r: dataset.bgr_frames(r), rec_mapping)
    frame_seqs = seqtools.smap(transform_frames, frame_seqs, transformations)

def test_smap_basics():
    n = 100
    data = [random.random() for _ in range(n)]

    def do(x):
        do.call_cnt += 1
        return x + 1

    do.call_cnt = 0

    # indexing
    result = smap(do, data)
    assert len(result) == len(data)
    assert do.call_cnt == 0
    assert list(result) == [x + 1 for x in data]
    assert do.call_cnt == n
    assert [result[i] for i in range(len(result))] == [x + 1 for x in data]
    assert list(result[:]) == [x + 1 for x in data]

def test_prefetch_timing(method):
    def f1(x):
        sleep(.02 + 0.01 * (random.random() - .5))
        return x

    arr = list(range(420))
    y = smap(f1, arr)
    y = prefetch(y, nworkers=2, max_buffered=20, method=method, timeout=1)

    for i in range(20):
        y[i]  # consume first items to eliminate worker startup time

    t1 = time()
    for i in range(20, 420):
        y[i]
    t2 = time()

    duration = t2 - t1
    print("test_prefetch_timing({}) {:.2f}s".format(method, duration))
    assert duration < 4.5

def test_prefetch_throughput(prefetch_kwargs):  # pragma: no cover
    def f1(x):
        sleep(.02 + 0.01 * (random.random() - .5))
        return x

    arr = np.random.rand(420, 10)
    y = smap(f1, arr)
    y = prefetch(y, nworkers=2, max_buffered=40, **prefetch_kwargs)

    for i in range(20):
        y[i]  # consume first items to eliminate worker startup time

    t1 = time()
    for i in range(20, 420):
        y[i]
    t2 = time()

    duration = t2 - t1
    print("test_prefetch_timing: {:.2f}s".format(duration))
    assert duration < 4.5

def test_cached_timing():
    def f(x):
        sleep(.01)
        return x

    cache_size = 3
    arr = [random.random() for _ in range(100)]
    y = smap(f, arr)
    z = add_cache(y, cache_size)

    t1 = time()
    for i in range(len(arr)):
        assert z[i] == arr[i]
        for j in range(max(0, i - cache_size + 1), i + 1):
            assert z[j] == arr[j]
    t2 = time()
    duration = t2 - t1

    print("test_cached_timing {:.2f}s".format(duration))
    assert duration < 1.2

def test_prefetch_errors(error_mode, prefetch_kwargs, picklable_err):
    class CustomError(Exception):
        pass

    def f1(x):
        if x is None:
            raise ValueError("blablabla") if picklable_err else CustomError()
        else:
            return x

    arr1 = [np.random.rand(10), np.random.rand(10), np.random.rand(10), None]
    arr2 = smap(f1, arr1)
    y = prefetch(arr2, nworkers=2, max_buffered=4, **prefetch_kwargs)

    seterr(error_mode)
    if (prefetch_kwargs['method'] != "thread" and not picklable_err) \
            or error_mode == "wrap":
        error_t = EvaluationError
    else:
        error_t = ValueError if picklable_err else CustomError

    for i in range(3):
        assert_array_equal(y[i], arr1[i])
    try:
        a = y[3]
    except Exception as e:
        assert type(e) == error_t

    if prefetch_kwargs['method'] == "process" and error_mode == "passthrough":
        class CustomObject:  # unpicklable object
            pass

        arr1 = [np.random.rand(10), CustomObject(), np.random.rand(10)]
        y = prefetch(arr1, nworkers=2, max_buffered=4, **prefetch_kwargs)
        with pytest.raises(ValueError):
            y[1]

def seq_all(_input_path):
    _sample_df = pd.read_pickle(_input_path)
    # index = _sample_df.index
    codes = _sample_df["code"]
    docs = _sample_df["docstring"]
    code_bytes = sq.smap(utf8encode, codes)
    languages = _sample_df["language"]
    parsers = sq.smap(get_parser, languages)
    asts = sq.smap(Parser.parse, parsers, code_bytes)
    code_tokens = _sample_df["code_tokens"]
    doc_tokens = _sample_df["docstring_tokens"]
    code_split_identifiers = sq.smap(indentifier_split, code_tokens)
    code_tokens_with_identifier_split = sq.smap(pd.Series.explode,
                                                code_split_identifiers)
    doc_split_identifiers = sq.smap(indentifier_split, doc_tokens)
    doc_tokens_with_identifier_split = sq.smap(pd.Series.explode,
                                               doc_split_identifiers)
    # Return every non-underscore local as a named lazy sequence.
    _dict_all = locals()
    _dict_return = {k: v for k, v in _dict_all.items() if not k.startswith("_")}
    # print(_dict_return.keys())
    return _dict_return

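# Hedged usage sketch (added): seq_all returns a dict of lazily evaluated
# seqtools sequences, so reading one element only encodes and parses that
# sample. `path` is a hypothetical pickled dataframe file.
def demo_seq_all(path):
    seqs = seq_all(path)
    print(sorted(seqs.keys()))  # codes, docs, asts, parsers, ...
    return seqs["asts"][0]  # triggers utf8encode + parsing for sample 0 only
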
import pandas as pd
import dataset_seq as ds
import ast_label_pretrain as ap
import seqtools as sq
from itertools import chain
from utils import fetch_snakemake_from_latest_run

try:
    snakemake
except NameError:
    snakemake = fetch_snakemake_from_latest_run(__file__)

seqs_all = ds.seq_all(snakemake.input[0])
seqs_labels = ap.seq_from_code_ast(seqs_all)
sub_code_indexes = seqs_labels["sub_code_indexes"]
type_label = seqs_labels[snakemake.params.label_type]
sample_ids = sq.smap(
    lambda samp_list, samp_index: [samp_index] * len(samp_list),
    sub_code_indexes, range(0, len(sub_code_indexes)))

df = pd.DataFrame({
    "index": chain.from_iterable(sub_code_indexes),
    "sample_id": chain.from_iterable(sample_ids),
    "label": chain.from_iterable(type_label),
})
# keep only pieces whose two offsets differ, i.e. drop empty spans
df = df[df["index"].apply(lambda x: x[0] != x[1])].reset_index(drop=True)
df.to_pickle(snakemake.output[0])

def fetch_code_pieces(codes, sample_ids, indexes):
    piece_full_code = sq.smap(lambda x: codes[x], sample_ids)
    code_pieces = sq.smap(_fetch_sub_code, indexes, piece_full_code)
    return code_pieces

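# Hedged sketch (added): `_fetch_sub_code` is defined elsewhere in this repo;
# given that `indexes` holds (start, end) pairs (see the x[0] != x[1] filter
# in the label-extraction script), a plausible implementation is a simple
# slice. This is an assumption, not the repo's actual code.
def _fetch_sub_code_sketch(index, full_code):
    start, end = index
    return full_code[start:end]
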
from __future__ import print_function

import time

import seqtools

files = [
    'file1', 'file2', 'file3', 'file4', 'file5',
    'file6', 'file7', 'file8', 'file9', 'file10'
]


def load(some_file):
    time.sleep(.1)
    return list(range(10) if some_file == 'file10' else range(200))


loaded_files = seqtools.smap(load, files)
loaded_files = seqtools.add_cache(loaded_files, 2)
all_samples = seqtools.unbatch(loaded_files, 200, 10)


def preprocess(x):
    t = time.perf_counter()
    while time.perf_counter() - t < 0.005:
        pass  # busy waiting
    return x


preprocessed_samples = seqtools.smap(preprocess, all_samples)
minibatches = seqtools.batch(preprocessed_samples, 64, collate_fn=list)

t1 = time.time()
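# Hedged completion (added): the original snippet is cut off after
# `t1 = time.time()`; a plausible continuation, matching the timing pattern
# used elsewhere in this document, iterates the minibatches once and reports
# the elapsed time.
for batch in minibatches:
    pass  # iterate once to trigger loading and preprocessing
t2 = time.time()
print("processed {} minibatches in {:.2f}s".format(len(minibatches), t2 - t1))
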
def prepare():
    global train_subset, val_subset, test_subset, \
        durations, gloss_seqs, pose2d_seqs, pose3d_seqs

    # Create temporary directory
    if not os.path.exists(cachedir):
        os.mkdir(cachedir)

    # Load data
    train_subset, val_subset, test_subset = dataset.default_splits()
    pose2d_seqs = [dataset.positions(i) for i in range(len(dataset))]
    pose3d_seqs = [dataset.positions_3d(i) for i in range(len(dataset))]

    # Eliminate strange gloss annotations
    gloss_seqs_train = [dataset.glosses(r) for r in train_subset]
    rejected = set()
    for r, gseq in zip(train_subset, gloss_seqs_train):
        for i in range(len(gseq) - 1):
            if gseq[i + 1, 1] - gseq[i, 2] < 0:
                rejected.add(r)
    train_subset = np.setdiff1d(train_subset, list(rejected))
    if len(rejected) > 0:
        logging.warning(
            "Eliminated sequences with invalid glosses: {}".format(rejected))

    # Interpolate missing poses and eliminate deteriorated training sequences
    invalid_masks = seqtools.smap(detect_invalid_pts, pose2d_seqs)
    pose2d_seqs = seqtools.smap(interpolate_positions, pose2d_seqs,
                                invalid_masks)
    pose3d_seqs = seqtools.smap(interpolate_positions, pose3d_seqs,
                                invalid_masks)
    rejected = np.where(
        [np.mean(im[:, important_joints]) > .15 for im in invalid_masks])[0]
    train_subset = np.setdiff1d(train_subset, rejected)
    if len(rejected) > 0:
        logging.warning(
            "eliminated {} sequences with missing positions".format(
                len(rejected)))

    # Default preprocessing
    ref2d = seqtools.add_cache(seqtools.smap(get_ref_pts, pose2d_seqs),
                               cache_size=1)
    ref3d = seqtools.add_cache(seqtools.smap(get_ref_pts, pose3d_seqs),
                               cache_size=1)
    transformations = np.rec.array(
        [(r2, r3, False, 0, tgt_dist - r3[2], 1., 1., 1., 1.)
         for r2, r3 in zip(ref2d, ref3d)],
        dtype=TransformationType)

    # Precompute transformations for augmentation of the training set
    original_train_subset = train_subset
    rec_mapping = np.arange(len(dataset))
    for _ in range(5 - 1):
        offset = len(rec_mapping)
        new_subset = np.arange(offset, offset + len(original_train_subset))
        newt = np.repeat(transformations[0], len(new_subset),
                         axis=0).view(np.recarray)
        newt.fliplr = uniform(size=len(newt)) < 0.15
        newt.tilt += uniform(-7, 7, size=len(newt)) * np.pi / 180
        # random multiplicative jitter around 1 for the scale factors
        newt.xscale *= uniform(.85, 1.15, size=len(newt))
        newt.yscale *= uniform(.85, 1.15, size=len(newt))
        newt.zscale *= uniform(.85, 1.15, size=len(newt))
        newt.tscale *= uniform(.85, 1.15, size=len(newt))
        rec_mapping = np.concatenate([rec_mapping, original_train_subset])
        transformations = np.concatenate([transformations,
                                          newt]).view(np.recarray)
        train_subset = np.concatenate([train_subset, new_subset])

    # Apply transformations (if they are cheap to compute)
    durations = np.array([
        transform_durations(dataset.durations(r), t)
        for r, t in zip(rec_mapping, transformations)
    ])
    gloss_seqs = [
        transform_glosses(dataset.glosses(r), dataset.durations(r), t)
        for r, t in zip(rec_mapping, transformations)
    ]
    pose2d_seqs = seqtools.gather(pose2d_seqs, rec_mapping)
    pose2d_seqs = seqtools.smap(
        partial(transform_pose2d, flip_mapping=flip_mapping, frame_width=640),
        pose2d_seqs, transformations)
    pose3d_seqs = seqtools.gather(pose3d_seqs, rec_mapping)
    pose3d_seqs = seqtools.smap(
        partial(transform_pose3d, flip_mapping=flip_mapping),
        pose3d_seqs, transformations)

    # Export
    np.save(os.path.join(cachedir, "pose2d_seqs.npy"),
            seqtools.concatenate(pose2d_seqs))
    np.save(os.path.join(cachedir, "pose3d_seqs.npy"),
            seqtools.concatenate(pose3d_seqs))
    with open(os.path.join(cachedir, "data.pkl"), 'wb') as f:
        pkl.dump((durations, gloss_seqs, rec_mapping, transformations,
                  train_subset, val_subset, test_subset), f)

    return query_embeddings, code_embeddings, losses, reiprocal_rank, mrr


# train, test and evaluation
ds_train = seq_all(input.train_data)
code_tokens = ds_train["code_tokens_with_identifier_split"]
doc_tokens = ds_train["doc_tokens_with_identifier_split"]
code_token_counter = count_all(code_tokens)
doc_token_counter = count_all(doc_tokens)

import seqtools as sq

code_tokenized = sq.smap(tokenize(code_token_counter), code_tokens)
doc_tokenized = sq.smap(tokenize(doc_token_counter), doc_tokens)
code_pad = sq.smap(pad_to_1d(code_token_counter, 200), code_tokenized)
doc_pad = sq.smap(pad_to_1d(doc_token_counter, 200), doc_tokenized)

# batch and pad dataset
model = CodeQuerySoftmaxBertModel(code_token_counter, doc_token_counter)
opt = torch.optim.Adam(model.parameters())

from tqdm import trange, tqdm

train_data = sq.collate([doc_pad, code_pad])
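# Hedged training-loop sketch (added): the snippet above ends before the
# actual loop. Everything below is an assumption: the minibatching via
# sq.batch, and a model that returns a scalar loss when called on stacked
# doc/code id tensors (the real forward signature of CodeQuerySoftmaxBertModel
# is not shown in this snippet). `n_epochs` and `batch_size` are illustrative.
def train_sketch(n_epochs=1, batch_size=32):
    batches = sq.batch(train_data, k=batch_size, collate_fn=list)
    for _ in trange(n_epochs):
        for batch in tqdm(batches):
            doc_batch, code_batch = zip(*batch)
            loss = model(torch.tensor(doc_batch), torch.tensor(code_batch))
            opt.zero_grad()
            loss.backward()
            opt.step()
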